vislearn / conditional_INNs

Code for the paper "Guided Image Generation with Conditional Invertible Neural Networks" (2019)

is NLL correct? #5

Open MokkeMeguru opened 4 years ago

MokkeMeguru commented 4 years ago

Hello, I wonder whether your NLL is correct.

In https://github.com/VLL-HD/conditional_invertible_neural_networks/blob/master/mnist_minimal_example/eval.py#L44-L45:

As I understand it, z is the latent variable and jac is the log-determinant of the Jacobian (log_det_jacobian) of the normalizing flow, so the code computes

 nll = -(log_prob(z) + log_det_jacobian) / pixels
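For reference, here is a minimal sketch of that per-dimension NLL, assuming a standard normal prior N(0, I) on z. The function name `nll_per_dim` is hypothetical and not taken from the repository. Implementations often drop the constant (d/2)·log(2π) because it does not affect training, but it matters when comparing reported numbers.

```python
import math

def nll_per_dim(z, log_det_jac):
    """Negative log-likelihood of one sample in nats per dimension.

    z: flat list of latent values produced by the flow.
    log_det_jac: log|det(dz/dx)| of the flow for this sample.
    Assumes a standard normal prior N(0, I) on z.
    """
    d = len(z)
    # log N(z; 0, I) = -0.5 * ||z||^2 - (d / 2) * log(2 * pi)
    log_pz = -0.5 * sum(v * v for v in z) - 0.5 * d * math.log(2 * math.pi)
    # Change of variables: log p(x) = log p(z) + log|det J|
    log_px = log_pz + log_det_jac
    # Negate to get the NLL and normalize by the number of dimensions
    return -log_px / d

print(nll_per_dim([0.0, 0.0], 0.0))  # 0.5 * log(2*pi), about 0.919
```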

But I think you forgot the cost of transforming the discrete image into a continuous one (dequantization).

In RealNVP, this cost is computed per image: https://github.com/tensorflow/models/blob/master/research/real_nvp/real_nvp_multiscale_dataset.py#L1063-L1077 (the same correction appears in Glow and Flow++).
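A minimal sketch of the correction, assuming 8-bit images that are uniformly dequantized and scaled to [0, 1) before entering the flow. If the model sees a different scaling or a logit transform (as RealNVP uses), the log-determinant of that preprocessing must be added as well. Both function names are hypothetical.

```python
import math
import random

def dequantize(pixels, n_bins=256, rng=None):
    """Uniform dequantization: integer pixels in {0, ..., n_bins - 1}
    become continuous values in [0, 1) via (pixel + u) / n_bins, u ~ U[0, 1)."""
    if rng is None:
        rng = random.Random()
    return [(p + rng.random()) / n_bins for p in pixels]

def discrete_bits_per_dim(nll_nats_per_dim, n_bins=256):
    """Turn the NLL of the dequantized data on [0, 1)^D (nats/dim) into an
    upper bound on the NLL of the discrete pixels, in bits/dim.

    Dividing by n_bins shrinks each unit-width integer bin to width 1/n_bins,
    so each dimension pays an extra log(n_bins) nats; dividing by log(2)
    converts nats to bits.
    """
    return (nll_nats_per_dim + math.log(n_bins)) / math.log(2)

x = dequantize([0, 128, 255], rng=random.Random(0))
print(x)  # three values in [0, 1)
print(discrete_bits_per_dim(0.0))  # about 8: log2(256) bits/dim of offset
```

Omitting the log(n_bins) term makes the reported NLL look better by exactly log2(n_bins) bits per dimension.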

psteinb commented 4 years ago

Maybe related to this: the term penalizing the magnitude of the trainable parameters θ is missing.

The preprint's loss (shown as an image in the original comment) ends with this weight term, but AFAIK the minimal examples do not include it: https://github.com/VLL-HD/conditional_invertible_neural_networks/blob/45dc7250ebfecf10d1a278edafde0fe899f30aa1/colorization_minimal_example/train.py#L25

I may be overlooking something, or tau*norm(theta)**2 is always constant, which is hard to believe at first sight.
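Assuming the preprint's objective takes the usual maximum-likelihood form (per-sample NLL with a standard normal latent) plus the weight term τ‖θ‖², the regularizer enters the gradient as 2τθ:

```latex
\mathcal{L}(\theta) = \mathbb{E}_{(x,c)}\!\left[\tfrac{1}{2}\lVert f(x; c, \theta)\rVert_2^2 - \log\lvert\det J_f(x; c, \theta)\rvert\right] + \tau\lVert\theta\rVert_2^2,
\qquad
\nabla_\theta\left(\tau\lVert\theta\rVert_2^2\right) = 2\tau\,\theta
```

So an optimizer that adds λθ to every gradient corresponds to τ = λ/2, even though nothing appears in the loss code.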

MokkeMeguru commented 4 years ago

@psteinb θ denotes the model's parameters, so the θ term is an L2 regularization over all parameters, including e.g. the inv1x1conv weights. (L2 regularization is often applied implicitly, because TensorFlow and PyTorch expose it through arguments such as a layer's kernel_regularizer in Keras or an optimizer's weight_decay in PyTorch. I hate this kind of implicit loss pollution.)

Tip: L2 regularization works well for inv1x1conv, as discussed in OpenAI's Glow: https://github.com/openai/glow/issues/40#issuecomment-462103120

psteinb commented 4 years ago

Thanks for the hint. I must confess that I think I understand the math, but maybe I am just too new to PyTorch and FrEIA.

Can you or @ardizzone et al. give me a direct pointer to where in this repo or its dependencies this L2 weight regularization is performed? I'd appreciate it.

MokkeMeguru commented 4 years ago

@ardizzone No. Look at how the neural networks are optimized.

Look at the hyperparameter weight_decay of the Adam optimizer: https://github.com/VLL-HD/conditional_invertible_neural_networks/blob/master/colorization_cINN/model.py#L249

See the explanation of weight_decay in the PyTorch documentation: https://pytorch.org/docs/stable/optim.html
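In PyTorch, `torch.optim.Adam` implements weight_decay by adding weight_decay·θ to the gradient before the moment estimates, which is the gradient of (weight_decay/2)·‖θ‖² added to the loss. A toy check of that equivalence, with a quadratic loss standing in for the NLL (an assumption for illustration only):

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(5)
target = torch.randn(5)
lr, wd, steps = 0.01, 0.1, 20

def data_loss(w):
    # Toy quadratic loss standing in for the NLL (assumption for illustration)
    return ((w - target) ** 2).sum()

# (a) L2 via Adam's weight_decay argument
w_a = w0.clone().requires_grad_(True)
opt_a = torch.optim.Adam([w_a], lr=lr, weight_decay=wd)

# (b) Explicit penalty (wd / 2) * ||w||^2 added to the loss, no weight_decay
w_b = w0.clone().requires_grad_(True)
opt_b = torch.optim.Adam([w_b], lr=lr)

for _ in range(steps):
    opt_a.zero_grad()
    data_loss(w_a).backward()
    opt_a.step()

    opt_b.zero_grad()
    (data_loss(w_b) + 0.5 * wd * (w_b ** 2).sum()).backward()
    opt_b.step()

max_diff = (w_a - w_b).abs().max().item()
print(max_diff)  # tiny: both runs follow the same trajectory up to rounding
```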

(But note that L2 regularization is generally recommended for kernel (weight) parameters, not for biases: https://stats.stackexchange.com/a/167608)
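One way to exempt biases in PyTorch is to pass per-parameter groups to the optimizer. This is a sketch on a hypothetical small model standing in for the cINN; the name-based `bias` check and the weight_decay value are illustrative assumptions.

```python
import torch
from torch import nn

# Hypothetical small model standing in for the cINN
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

decay, no_decay = [], []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    # Regularize weight matrices/kernels only; leave biases unpenalized
    (no_decay if name.endswith("bias") else decay).append(p)

optimizer = torch.optim.Adam(
    [
        {"params": decay, "weight_decay": 1e-5},  # illustrative value
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-4,
)
print(len(decay), len(no_decay))  # 2 2
```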

psteinb commented 4 years ago

Alright, prompted by your reply I looked into this a bit. As mentioned, please excuse my limited experience with PyTorch and feel free to correct me at any point.

So, if I understand correctly, the L2 regularization term in equation (6) of the preprint is assumed to be baked into lines like https://github.com/VLL-HD/conditional_invertible_neural_networks/blob/45dc7250ebfecf10d1a278edafde0fe899f30aa1/colorization_cINN/model.py#L249

But I am starting to believe that this assumption does not hold, and that AdamW should be used instead. Here is the story:

So this supports my notion expressed above that the code does not do what the paper promises, i.e. apply L2 regularization to the weights in the loss term. The only mitigation I can suggest is adding the L2 regularization explicitly to the loss term, e.g. here: https://github.com/VLL-HD/conditional_invertible_neural_networks/blob/45dc7250ebfecf10d1a278edafde0fe899f30aa1/colorization_cINN/train.py#L67

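# untested code! tau is a new hyperparameter; model.params_trainable is assumed to be a list of tensors
l = torch.mean(neg_log_likeli) / tot_output_size \
    + tau * sum(torch.sum(p ** 2) for p in model.params_trainable)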
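To see how AdamW differs from Adam's weight_decay: AdamW uses decoupled weight decay, shrinking θ ← θ·(1 − lr·wd) outside the adaptive update, whereas Adam adds wd·θ to the gradient, which is then divided by √v̂. A one-step sketch with a made-up, fixed gradient; the hand-written update rules are the standard first-step Adam formulas, and both comparisons should hold up to floating-point rounding.

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(5)
g = torch.randn(5)  # made-up, fixed gradient of the data loss
lr, wd, eps = 0.01, 0.1, 1e-8

def one_step(opt_cls, **kwargs):
    """Apply a single optimizer step to a copy of w0 with gradient g."""
    w = w0.clone().requires_grad_(True)
    opt = opt_cls([w], lr=lr, eps=eps, **kwargs)
    w.grad = g.clone()
    opt.step()
    return w.detach()

w_adam = one_step(torch.optim.Adam, weight_decay=wd)
w_adamw = one_step(torch.optim.AdamW, weight_decay=wd)

# On the first step Adam's bias-corrected update is lr * grad / (|grad| + eps).
# Adam (coupled L2): wd * w0 is added to the gradient *before* that normalization.
g_coupled = g + wd * w0
expected_adam = w0 - lr * g_coupled / (g_coupled.abs() + eps)
# AdamW (decoupled): shrink w0 directly, then take the plain Adam step on g.
expected_adamw = w0 * (1 - lr * wd) - lr * g / (g.abs() + eps)

print(torch.allclose(w_adam, expected_adam, atol=1e-6))
print(torch.allclose(w_adamw, expected_adamw, atol=1e-6))
print((w_adam - w_adamw).abs().max().item())  # the two rules differ
```

Only the coupled form literally corresponds to a penalty term in the loss; with an adaptive optimizer, the decoupled form is not the gradient of a fixed penalty.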

Or use plain SGD instead.
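With plain SGD (no momentum) the distinction disappears: θ − lr·(g + wd·θ) = θ·(1 − lr·wd) − lr·g, so weight_decay is both an L2 penalty with τ = wd/2 and a decoupled shrinkage. A one-step check with a made-up, fixed gradient:

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(5)
g = torch.randn(5)  # made-up, fixed gradient of the data loss
lr, wd = 0.1, 0.01

w = w0.clone().requires_grad_(True)
opt = torch.optim.SGD([w], lr=lr, weight_decay=wd)  # plain SGD, no momentum
w.grad = g.clone()
opt.step()

# Coupled L2: gradient of loss + (wd / 2) * ||w||^2 is g + wd * w0
coupled = w0 - lr * (g + wd * w0)
# Decoupled decay: shrink first, then a plain gradient step
decoupled = w0 * (1 - lr * wd) - lr * g

print((w.detach() - coupled).abs().max().item())
print((w.detach() - decoupled).abs().max().item())
```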

MokkeMeguru commented 4 years ago

Yeah, you are correct. Adam is a black box for me, too.

We would also have to read all of their release notes, which I don't recommend. Instead, we should implement a custom Adam, use SGD, or use AdamW.

Or... OpenAI's Glow and many other flow-based model implementations don't apply L2 regularization to all trainable parameters.