yilundu / improved_contrastive_divergence

[ICML'21] Improved Contrastive Divergence Training of Energy Based Models

no label_map function in ResNetModel #6

Closed: scilearner closed this issue 3 years ago

scilearner commented 3 years ago

Hi, thanks for the impressive papers (this one and IGEBM) and the PyTorch code. I have two questions about conditional training.

  1. I found the line latent = self.label_map(latent), but the label_map function is missing from ResNetModel, which is used for CIFAR-10. MNISTModel and CelebAModel both define it. I'm not sure which mapping, or which layer shapes, you use for CIFAR-10.

  2. What's the difference between your unconditional and conditional EBM models? Conditional EBM training uses per-class conditional gains and biases [Dumoulin et al.], so the conditional model can reach ~50% classification accuracy. I think the unconditional model cannot classify or cluster well. Is that right?
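The per-class gains and biases mentioned in question 2 can be illustrated with a minimal NumPy sketch of class-conditional normalization in the spirit of Dumoulin et al. Features are normalized per channel, then scaled and shifted by parameters selected by the class label. This sketch assumes instance-style normalization over the spatial axes. The function name, shapes, and toy data are illustrative and are not taken from the repository.

```python
import numpy as np

def conditional_norm(x, labels, gains, biases, eps=1e-5):
    """Class-conditional normalization (illustrative sketch).

    x:      feature maps, shape (N, C, H, W)
    labels: integer class ids, shape (N,)
    gains:  per-class scale, shape (num_classes, C)
    biases: per-class shift, shape (num_classes, C)
    """
    # Normalize each sample's channels over the spatial dimensions.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Select each sample's gain/bias row by class, then broadcast over H and W.
    g = gains[labels][:, :, None, None]
    b = biases[labels][:, :, None, None]
    return g * x_hat + b

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 5, 5))     # 4 samples, 8 channels, 5x5 maps
labels = np.array([0, 3, 3, 9])       # CIFAR-10-style class ids
gains = rng.normal(size=(10, 8))
biases = rng.normal(size=(10, 8))
out = conditional_norm(x, labels, gains, biases)
```

The normalized features have zero spatial mean, so each output channel's mean equals the selected bias. Samples 1 and 2 share class 3 and therefore receive the same shift.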

Thanks a lot.

yilundu commented 3 years ago

Hi,

1) Thanks for your questions. The self.label_map function maps an input one-hot label into a latent embedding, which is then mapped to conditional gains and biases in each ResNet block. In ResNetModel I do not map the one-hot label into a latent embedding. Instead, the one-hot vector is mapped directly into conditional gains and biases in the ResNet block. I think the code should run directly for CIFAR-10. Please let me know if it does not.
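The two routes described in this answer can be compared in a small NumPy sketch. In the first route, the one-hot label is mapped directly to per-channel gains, as described for ResNetModel. In the second, label_map first produces a latent embedding. The weight matrices stand in for learned linear layers, and LATENT_DIM and CHANNELS are hypothetical sizes, not values from the repository.

```python
import numpy as np

NUM_CLASSES = 10   # CIFAR-10
LATENT_DIM = 128   # hypothetical embedding size for the label_map route
CHANNELS = 64      # hypothetical channel count of one residual block

rng = np.random.default_rng(0)

def one_hot(labels, num_classes=NUM_CLASSES):
    # Rows of the identity matrix are one-hot vectors.
    return np.eye(num_classes)[labels]

# Route 1 (direct): one-hot -> per-channel gains.
W_gain_direct = rng.normal(size=(NUM_CLASSES, CHANNELS))

# Route 2 (label_map style): one-hot -> latent embedding -> per-channel gains.
W_label_map = rng.normal(size=(NUM_CLASSES, LATENT_DIM))  # stands in for label_map
W_gain_latent = rng.normal(size=(LATENT_DIM, CHANNELS))

labels = np.array([2, 7])
y = one_hot(labels)

gains_direct = y @ W_gain_direct   # shape (2, CHANNELS)
latent = y @ W_label_map           # shape (2, LATENT_DIM)
gains_mapped = latent @ W_gain_latent  # shape (2, CHANNELS)
```

Multiplying a one-hot vector by a weight matrix is simply a row lookup, so the direct route behaves like a per-class table of gains. In this purely linear sketch, the second route collapses to a single matrix product. A real label_map would mainly add value if it includes a nonlinearity or is shared across blocks.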

2) Whether a model is conditional or unconditional is determined by the FLAGS.cond flag (set with --cond). When FLAGS.cond is False, no conditional gains or biases are passed into the residual network. When it is True, they are passed in. The unconditional model is not able to classify.
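The role of FLAGS.cond described above can be shown with a toy residual block that accepts an optional conditioning vector. This is a minimal sketch under stated assumptions: residual_block and forward are hypothetical helpers, not the repository's classes, and the block operates on flat (N, C) features for brevity.

```python
import numpy as np

def residual_block(x, latent=None, w_gain=None, w_bias=None):
    """Toy residual block on (N, C) features (hypothetical).

    If latent is None (unconditional), no label information enters the block.
    Otherwise latent is mapped to per-channel gains and biases.
    """
    h = np.tanh(x)
    if latent is not None:
        h = (latent @ w_gain) * h + latent @ w_bias
    return x + h

def forward(x, labels_one_hot, cond, w_gain, w_bias):
    # Mirrors the described behavior of FLAGS.cond: pass the label only if conditional.
    latent = labels_one_hot if cond else None
    return residual_block(x, latent, w_gain, w_bias)

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
w_gain = rng.normal(size=(10, 4))
w_bias = rng.normal(size=(10, 4))
y_a = np.eye(10)[[0, 1, 2]]
y_b = np.eye(10)[[5, 5, 5]]

uncond_a = forward(x, y_a, False, w_gain, w_bias)
uncond_b = forward(x, y_b, False, w_gain, w_bias)
cond_a = forward(x, y_a, True, w_gain, w_bias)
cond_b = forward(x, y_b, True, w_gain, w_bias)
```

With cond=False, the output is identical whatever labels are supplied. This is consistent with the reply's point that the unconditional model has no class signal to classify with.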

Let me know if you have any questions! Yilun

scilearner commented 3 years ago

Thanks for your reply.

  1. I ran python train.py --exp=cifar10_model --step_lr=100.0 --num_steps=40 --cuda --ensembles=1 --kl_coeff=1.0 --kl=True --self_attn --cond with the modification you mentioned (https://github.com/yilundu/improved_contrastive_divergence/blob/master/models.py#L394).

It fails with an error at the line that uses the latent: x = self.res_1a(x, latent).

I also tried adding a mapping from the one-hot label into a latent embedding, and then it worked. So I think there's a small mistake in the code.
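The thread does not say what the error was. One plausible reading, stated here purely as an assumption, is a shape mismatch: the residual blocks expect a latent of some fixed size but receive the raw 10-dimensional one-hot vector. Inserting a label_map layer restores the expected size. In the NumPy sketch below, LATENT_DIM, CHANNELS, and the ReLU inside the hypothetical label_map are illustrative choices, not values from the repository.

```python
import numpy as np

NUM_CLASSES = 10
LATENT_DIM = 128  # hypothetical size the residual blocks were built for
CHANNELS = 64     # hypothetical channel count

rng = np.random.default_rng(0)
# Block weights that expect a LATENT_DIM-dimensional conditioning vector.
w_gain = rng.normal(size=(LATENT_DIM, CHANNELS))
# Hypothetical label_map weights: one-hot (10) -> latent (LATENT_DIM).
w_label_map = rng.normal(size=(NUM_CLASSES, LATENT_DIM))

def block_gain(latent):
    # Per-channel gains computed inside a (hypothetical) residual block.
    return latent @ w_gain

y = np.eye(NUM_CLASSES)[[3, 8]]  # two one-hot labels, shape (2, 10)

# Passing the raw one-hot vector: (2, 10) @ (128, 64) has mismatched inner dims.
try:
    block_gain(y)
    failed = False
except ValueError:
    failed = True

# Adding a label_map step (linear layer + ReLU, an assumption) fixes the shape.
latent = np.maximum(y @ w_label_map, 0.0)
gains = block_gain(latent)
```

This mirrors the fix described in the comment above: mapping the one-hot label to a latent embedding before it reaches the residual blocks.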

  2. Thank you very much!

yilundu commented 3 years ago

Hi, thanks for raising the issue. I've pushed some changes to the code so that it should run in the CIFAR-10 conditional setting.