EugenHotaj / pytorch-generative

Easy generative modeling in PyTorch
MIT License

Loading 3 Channels Images to PixelSnail #18

Closed eyalbetzalel closed 4 years ago

eyalbetzalel commented 4 years ago

Hi!

I'm trying to load the ImageNet dataset into the PixelSNAIL model.

After changing these lines in the pixel_snail.py script:

```python
class PixelSNAIL(base.AutoregressiveModel):
  """The PixelSNAIL model.
  Unlike [1], we implement skip connections from each block to the output.
  We find that this makes training a lot more stable and allows for much faster
  convergence.
  """
  def __init__(self,
               in_channels=3,  # <------ I changed this
               out_dim=1,
               probs_fn=torch.sigmoid,
               sample_fn=lambda x: distributions.Bernoulli(probs=x).sample(),
               n_channels=64,
               n_pixel_snail_blocks=8,
               n_residual_blocks=2,
               attention_key_channels=4,
               attention_value_channels=32,
               head_channels=1):
    """Initializes a new PixelSNAIL instance."""
    ...
```

I still get this error:

```
RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[128, 3, 64, 64] to have 1 channels, but got 3 channels instead
```
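The message means a convolution was built for 1 input channel (its weight is stored as `[out_channels, in_channels, kH, kW] = [64, 1, 3, 3]`) but received a batch with 3 channels. The minimal sketch below reproduces the same kind of mismatch with a plain `torch.nn.Conv2d`; it is not the PixelSNAIL code, and it uses a batch of 2 instead of 128 only to keep it light. The exact wording of the error can differ between PyTorch versions.

```python
import torch
from torch import nn

# A conv layer created for 1 input channel stores its weight as
# [out_channels, in_channels, kH, kW] = [64, 1, 3, 3].
conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
print(tuple(conv1.weight.shape))  # (64, 1, 3, 3)

# A small batch of 3-channel 64x64 images (the error above had batch size 128).
x = torch.randn(2, 3, 64, 64)

raised = False
try:
    conv1(x)  # channel mismatch: layer expects 1 channel, input has 3
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Creating the layer with in_channels=3 makes the shapes line up;
# padding=1 with a 3x3 kernel keeps the 64x64 spatial size.
conv3 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
out = conv3(x)
print(tuple(out.shape))  # (2, 64, 64, 64)
```

So changing `in_channels` on the top-level class only helps if every layer that touches the raw input is actually constructed with that value.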

Am I missing something? Should I do anything else when loading 3-channel images?

Thanks, Eyal

EugenHotaj commented 4 years ago

Thanks for filing this!

There was a bug in the PixelSnail code where the in_channels argument was not getting passed correctly to some of the underlying blocks. Should be fixed now in dbe3828a4287f0b575f8b6fe726e2b6de94dfbf0. Give it another shot and let me know if you run into any issues.
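The general bug pattern described here (a constructor argument that is accepted but never forwarded to a sub-module) can be sketched with two toy models. `BuggyModel` and `FixedModel` are hypothetical illustrations, not the actual PixelSNAIL code or the diff in the commit above.

```python
import torch
from torch import nn


class BuggyModel(nn.Module):
    """Hypothetical model that accepts in_channels but never uses it."""

    def __init__(self, in_channels=1, n_channels=64):
        super().__init__()
        # Bug pattern: the first layer hard-codes 1 input channel.
        self.stem = nn.Conv2d(1, n_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.stem(x)


class FixedModel(nn.Module):
    """Hypothetical model that forwards in_channels to its first layer."""

    def __init__(self, in_channels=1, n_channels=64):
        super().__init__()
        # Fix: pass the constructor argument through to the sub-module.
        self.stem = nn.Conv2d(in_channels, n_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.stem(x)


x = torch.randn(2, 3, 32, 32)  # batch of 3-channel images

fixed_out = FixedModel(in_channels=3)(x)
print(tuple(fixed_out.shape))  # (2, 64, 32, 32)

buggy_failed = False
try:
    BuggyModel(in_channels=3)(x)  # weight is [64, 1, 3, 3], input has 3 channels
except RuntimeError:
    buggy_failed = True
print(buggy_failed)  # True
```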

By the way, I noticed you may be using an old version of the code; I'd suggest pulling or rebasing.

eyalbetzalel commented 4 years ago

@EugenHotaj

Thanks for the quick reply :)

I'll try it and let you know.

EugenHotaj commented 4 years ago

Going to close this out, please reopen if you see any more issues!