EugenHotaj / pytorch-generative

Easy generative modeling in PyTorch.
MIT License
422 stars, 68 forks

Replicating NLL Results for PIXEL CNN #21

Closed: anton-jeran closed this issue 3 years ago

anton-jeran commented 3 years ago

Could you please let me know how I can replicate the NLL results in the paper? For how many epochs should I train? Could you also give me a script to generate an image using the trained model?

EugenHotaj commented 3 years ago

Hey @anton-jeran,

> Could you please let me know how I can replicate the NLL results in the paper? For how many epochs should I train?

For training the models, take a look at train.py.

Which results are you trying to replicate? Generally, you should use the hyperparameters discussed in the paper. Here is the training script that I used for Binarized MNIST (note that this won't work for CIFAR10 and would have to be updated accordingly):

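import torch
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler

import pytorch_generative as pg

# PixelCNN, train_loader, and test_loader are assumed to be defined or imported
# elsewhere (e.g. PixelCNN from the library's models module, and the loaders as
# Binarized MNIST DataLoaders).
N_EPOCHS = 500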

model = PixelCNN(in_channels=1,
                 out_channels=1,
                 residual_channels=16,
                 head_channels=32,
                 n_residual_blocks=15)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda _: .999977)

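def loss_fn(x, _, preds):
  batch_size = x.shape[0]
  x, preds = x.view((batch_size, -1)), preds.view((batch_size, -1))
  # Keep the per-pixel losses (no reduction), sum them over each image's pixels,
  # then average over the batch.
  loss = F.binary_cross_entropy_with_logits(preds, x, reduction="none")
  return loss.sum(dim=1).mean()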

trainer = pg.trainer.Trainer(model=model, 
                             loss_fn=loss_fn, 
                             optimizer=optimizer, 
                             train_loader=train_loader, 
                             eval_loader=test_loader,
                             lr_scheduler=scheduler,
                             device=torch.device("cuda"))
trainer.interleaved_train_and_eval(N_EPOCHS)
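
For reference, the loss in the script is the negative log-likelihood of each image under a per-pixel Bernoulli model, summed over pixels and averaged over the batch, so the number it reports is in nats per image. Here is a minimal pure-Python sketch of that reduction. The helper names `bce_with_logits` and `bernoulli_nll` and the toy data are made up for illustration and are not part of the library:

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for a single pixel, in nats."""
    # Equivalent to -[t * log(sigmoid(z)) + (1 - t) * log(1 - sigmoid(z))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def bernoulli_nll(batch_logits, batch_targets):
    """Sum the per-pixel loss over each image, then average over the batch."""
    per_image = [
        sum(bce_with_logits(z, t) for z, t in zip(logits, targets))
        for logits, targets in zip(batch_logits, batch_targets)
    ]
    return sum(per_image) / len(per_image)

# Toy batch (illustrative only): two "images" of four pixels each, with
# all-zero logits, i.e. the model predicts p = 0.5 for every pixel.
logits = [[0.0] * 4, [0.0] * 4]
targets = [[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
nll = bernoulli_nll(logits, targets)  # every pixel costs ln(2) nats
```

By the same arithmetic, a model that outputs zero logits everywhere would score 784 · ln 2 ≈ 543.4 nats on a 28×28 binarized image, which can serve as a rough sanity baseline when comparing against reported NLL numbers.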

> Could you please give me a script to generate an image using the trained model?

Once the model is trained you can generate images by calling model.sample() 😃.
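
To give a rough idea of what sampling from an autoregressive model like this involves: pixels are drawn one at a time in raster order, each from a Bernoulli distribution conditioned on the pixels generated so far. The sketch below is not the library's implementation. `toy_conditional` is a hypothetical stand-in for a trained model, and `sample_image` is an illustrative name:

```python
import random

def toy_conditional(previous_pixels):
    """Hypothetical stand-in for a trained model: P(pixel = 1 | previous pixels)."""
    if not previous_pixels:
        return 0.5
    # Favor repeating the previous pixel, purely for illustration.
    return 0.9 if previous_pixels[-1] == 1 else 0.1

def sample_image(height, width, conditional, rng):
    """Draw pixels one at a time in raster order, each conditioned on earlier ones."""
    pixels = []
    for _ in range(height * width):
        p = conditional(pixels)
        pixels.append(1 if rng.random() < p else 0)
    # Reshape the flat list of pixels into rows.
    return [pixels[r * width:(r + 1) * width] for r in range(height)]

# A seeded generator makes the draw repeatable.
image = sample_image(4, 4, toy_conditional, random.Random(0))
```

In a real PixelCNN the conditional comes from a forward pass of the network, and the masked convolutions ensure each pixel's prediction depends only on the pixels before it.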

EugenHotaj commented 3 years ago

Hey @anton-jeran,

I've now added reproduce functions to all models which encode the hyperparameters. You should be able to reproduce the results by simply calling that function.

A super simple way to do this is to just run python train.py --model pixel_cnn.