EugenHotaj / pytorch-generative

Easy generative modeling in PyTorch.
MIT License

Sampling on 3 channels looks corrupted #20

Closed: eyalbetzalel closed this 2 years ago

eyalbetzalel commented 3 years ago

Hi

I trained PixelSnail on CIFAR10 and am trying to sample from the model with this function I wrote (based on the code from this project's notebook):

        print("Sample")

        if self._epoch % 10 == 0 :
            print("Epoch Number: " + str(self._epoch))
            print("sampling")
            curr_path = 'sample_from_epoch_' + str(self._epoch) + '.png'
            print(curr_path)
            sampleTensor=self._model.sample((10, 3, 32, 32))
            sampleTensor=sampleTensor.cpu()
            cu.imsave(sampleTensor, figsize=(50, 5),filename = curr_path)

      self._summary_writer.close()

Where cu.imsave is:

import matplotlib.pyplot as plt

def imsave(batch_or_tensor, title=None, figsize=None, filename="sample.png"):
  """Renders tensors as an image using Matplotlib.

  Args:
    batch_or_tensor: A batch or single tensor to render as images. If the batch
      size > 1, the tensors are flattened into a horizontal strip before being
      rendered.
    title: The title for the rendered image. Passed to Matplotlib.
    figsize: The size (in inches) for the image. Passed to Matplotlib.
    filename: The path the rendered image is saved to.
  """
  batch = batch_or_tensor
  # Promote single images/tensors to a 4D (n, c, h, w) batch.
  for _ in range(4 - batch.ndim):
    batch = batch.unsqueeze(0)
  n, c, h, w = batch.shape
  # Flatten the batch into one horizontal strip of shape (c, h, n * w).
  tensor = batch.permute(1, 2, 0, 3).reshape(c, h, -1)
  # _IMAGE_UNLOADER (defined elsewhere in the project) converts the tensor
  # to an image that Matplotlib can render.
  image = _IMAGE_UNLOADER(tensor)

  plt.figure(figsize=figsize)
  plt.title(title)
  plt.axis('off')
  # Note: plt.imsave() writes the array directly and ignores the figure,
  # title, and axis settings above.
  plt.imsave(filename, image)

For some reason, the output after 90 epochs looks bad:

[attached image: sample_from_epoch_90.png]

Am I missing anything?

EugenHotaj commented 3 years ago

Hey @eyalbetzalel,

Could you post your whole training script?

Here are a few things to look into:

1. The PixelSnail implementation assumes by default that your data comes from a Bernoulli distribution (i.e. all pixel values are either 0 or 1), which is a reasonable assumption for Binarized MNIST. For CIFAR10 (and even regular MNIST), the pixel values can take any value between 0 and 255, so the Bernoulli assumption no longer makes sense and we must try something else. The original PixelCNN paper assumes that CIFAR10 data comes from a Categorical distribution with 256 values, while follow-up papers (such as PixelCNN++ and PixelSnail) assume the data comes from a discretized logistic mixture. What this means in practice is that our loss function needs to change. For the categorical case, the change is simply to replace BCEWithLogitsLoss with NLLLoss. The discretized logistic mixture case would need a custom loss which we have not implemented yet (but happy to work together if you want to submit a PR!).

2. Similarly, we must also update our sample_fn. By default we use a Bernoulli distribution for sampling, but we'd instead need to use the appropriate sampling function (for example, the Categorical distribution in the categorical case). A rough sketch of both changes is below.
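To make the categorical case concrete, here is a minimal sketch of what both changes could look like. This is not pytorch-generative's actual API: it assumes the model's output head is widened to emit 256 logits per subpixel (shape (N, 256, C, H, W)), and both function names are made up for illustration.

    import torch
    import torch.nn.functional as F

    def categorical_loss(logits, targets):
        # logits: (N, 256, C, H, W); targets: (N, C, H, W) integer pixel
        # values in [0, 255]. cross_entropy = log-softmax over dim 1 followed
        # by NLLLoss, treating the 256 pixel intensities as classes.
        return F.cross_entropy(logits, targets.long())

    def categorical_sample_fn(logits):
        # Draw one of the 256 intensities per subpixel from the predicted
        # Categorical distribution, then rescale to [0, 1] for the model input.
        n, k, c, h, w = logits.shape
        probs = logits.permute(0, 2, 3, 4, 1).reshape(-1, k).softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1)
        return samples.reshape(n, c, h, w).float() / (k - 1)

Note that F.cross_entropy is just a log-softmax followed by NLLLoss, so it performs the loss substitution described in point 1.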

There are also a lot more details in the papers mentioned above.

eyalbetzalel commented 3 years ago

Thanks for your reply.

I will look further into this and let you know :)