phlippe / CategoricalNF

Official repository for "Categorical Normalizing Flows via Continuous Transformations"
https://arxiv.org/abs/2006.09790
MIT License

The implementation of your importance sampling for estimating likelihood #2

Closed: Xiaohui9607 closed this issue 1 year ago

Xiaohui9607 commented 1 year ago

Hi, I was wondering which part is the implementation of your importance sampling? I couldn't find it...

phlippe commented 1 year ago

Hi, you are right that this seems to be missing from the code. I am not sure whether I left it out for simplicity back then, since the scores did not differ significantly, or whether it was genuinely missing. In any case, if you want to re-introduce importance sampling during testing, you can replace the following lines in the eval(...) function in task.py:

batch_size = batch[0].size(0) if isinstance(batch, tuple) else batch.size(0)
batch_nll = self._eval_batch(batch, is_test=is_test)

replace with

batch_size = batch[0].size(0) if isinstance(batch, tuple) else batch.size(0)
if not is_test:
    batch_nll = self._eval_batch(batch, is_test=False)
else:
    # Importance sampling: evaluate the batch num_rep times and average
    # the likelihoods (not the log-likelihoods) in log-space via logsumexp.
    all_ll = []
    for _ in range(num_rep):  # num_rep: number of importance samples
        # negate the returned NLL to obtain per-sample log-likelihoods
        batch_ll = -self._eval_batch(batch, is_test=True)
        all_ll.append(batch_ll)
    # log( (1/num_rep) * sum_k exp(ll_k) )
    batch_ll = torch.logsumexp(torch.stack(all_ll, dim=-1), dim=-1) - np.log(num_rep)
    batch_nll = -batch_ll.mean()

Additionally, the _eval_batch implementation for the task you are interested in should not average over the batch dimension of the NLL score when is_test=True.
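As a side note, the logsumexp line computes log((1/num_rep) * sum_k exp(ll_k)), i.e. the log of the averaged likelihoods, in a numerically stable way. A minimal self-contained sketch with made-up numbers (not code from this repo) shows why the naive average underflows for very negative log-likelihoods:

import math
import torch

num_rep = 8                                # hypothetical number of importance samples
all_ll = torch.randn(4, num_rep) - 200.0   # fake per-sample log-likelihoods

# stable: subtract the max internally before exponentiating
stable = torch.logsumexp(all_ll, dim=-1) - math.log(num_rep)
# naive: exp(-200) underflows to 0 in float32, so the log becomes -inf
naive = torch.log(torch.exp(all_ll).mean(dim=-1))

print(stable)  # finite values around -200
print(naive)   # -inf due to underflow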

Xiaohui9607 commented 1 year ago

logsumexp should operate on ll, not nll, right?

phlippe commented 1 year ago

True, in my code the _eval_batch function actually returned log-likelihoods when is_test=True. Sorry for forgetting that here, I'll adjust the code.
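For reference, a sketch of what the adjusted else branch could look like, assuming _eval_batch returns a tuple whose first element is the per-sample log-likelihoods when is_test=True:

all_ll = []
for _ in range(num_rep):
    # _eval_batch is assumed to return (log-likelihoods, ...) when is_test=True,
    # so no negation is needed before logsumexp
    batch_ll, _ = self._eval_batch(batch, is_test=True)
    all_ll.append(batch_ll)
batch_ll = torch.logsumexp(torch.stack(all_ll, dim=-1), dim=-1) - np.log(num_rep)
batch_nll = -batch_ll.mean()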