Hey @eyalbetzalel,
Could you post your whole training script?
Here are a few things to look into:
1) The PixelSnail implementation assumes by default that your data comes from a Bernoulli distribution (i.e. all the pixel values are either 0 or 1), which is a reasonable assumption for Binarized MNIST. For CIFAR10 (and even regular MNIST), the pixel values can take any value between 0 and 255, so the Bernoulli assumption no longer makes sense and we must try something else. The original PixelCNN paper assumes that CIFAR10 data comes from a categorical distribution with 256 values, while follow-up papers (such as PixelCNN++ and PixelSnail) assume the data comes from a discretized logistic mixture. In practice this means the loss function needs to change. For the categorical case, the change is simply to replace BCEWithLogitsLoss with NLLLoss. The discretized logistic mixture case would need a custom loss which we have not implemented yet (but happy to work together if you want to submit a PR!).
2) You would also need to change the sampling function sample_fn. By default we use a Bernoulli distribution for sampling, but we'd instead need to use the appropriate sampling function (for example the Categorical distribution in the categorical case); a sketch of both changes follows this list. There are also a lot more details in the papers I linked above.
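For concreteness, here is a minimal sketch of what those two changes might look like in PyTorch. The names (logits, targets, categorical_loss, categorical_sample_fn) are hypothetical, not from this repo, and the output tensor is assumed to be reshaped to (batch, 256, C, H, W) with integer pixel targets of shape (batch, C, H, W):

import torch
import torch.nn.functional as F

def categorical_loss(logits, targets):
    # Cross-entropy over the 256 intensity classes replaces BCEWithLogitsLoss;
    # F.cross_entropy is log-softmax followed by NLLLoss.
    return F.cross_entropy(logits, targets)

def categorical_sample_fn(logits):
    # Draw each pixel from its categorical distribution and rescale the
    # sampled class index back to [0, 1] for the next autoregressive step.
    probs = F.softmax(logits, dim=1)                    # (B, 256, C, H, W)
    b, k, c, h, w = probs.shape
    flat = probs.permute(0, 2, 3, 4, 1).reshape(-1, k)  # (B*C*H*W, 256)
    samples = torch.multinomial(flat, num_samples=1)    # class indices
    return samples.reshape(b, c, h, w).float() / (k - 1)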
Thanks for your reply.
I will look further into this and let you know :)
Hi
I trained PixelSnail on CIFAR10 and tried to sample from the model with this function that I wrote (based on the code from this project's notebook):
Where cu.imsave is:
For some reason, the output after 90 epochs looks bad:
Am I missing anything?