pclucas14 / pixel-cnn-pp

PyTorch implementation of OpenAI's PixelCNN++

Is the causality constraint satisfied? #24

Closed Weafre closed 2 years ago

Weafre commented 2 years ago

Hi, thanks for the implementation.

The idea behind autoregressive modeling (PixelCNN, PixelCNN++, ...) is that pixels are generated sequentially, each depending only on the previous pixels; this is the so-called causality constraint.

With the up- and downsampling layers in your network, I think this constraint is not satisfied.

Please correct me if I am wrong.

Here is the code:

import torch
from torch.autograd import Variable  # old pre-0.4 API; volatile=True is the predecessor of torch.no_grad()

obs = (3, 32, 32)   # (channels, height, width) for CIFAR
model.train(False)  # model trained on the cifar dataset
data = torch.zeros(1, obs[0], obs[1], obs[2])
data = data.cuda()

data_v = Variable(data, volatile=True)
out = model(data_v, sample=True)

# check the output at spatial position (20, 20)
print(out[0,10:13,20,20])  # returns: tensor([-0.4630, -0.0477,  0.2698])

data1 = data.clone()  # clone, so the original input is not modified in place
# change the rgb value at a future pixel in the input
data1[0,:,25,24] = 1000
data_v1 = Variable(data1, volatile=True)
out1 = model(data_v1, sample=True)

# and check the output again
print(out1[0,10:13,20,20])  # returns: tensor([-0.4629, -0.0476,  0.2698])

There is only a slight difference between the two output tensors, but in theory they should be identical.
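
A tolerance-based comparison makes the size and extent of the discrepancy explicit; a minimal sketch, reusing out and out1 from the snippet above:

# quantify the discrepancy instead of eyeballing printed values
diff = (out - out1).abs()
print(diff[0, :, 20, 20].max().item())       # largest change at the probed position
print(torch.allclose(out, out1, atol=1e-6))  # True only if the outputs match up to tolerance
changed = (diff.sum(dim=1) > 0)[0]           # (H, W) mask of spatial positions whose output moved
print(changed.nonzero())                     # under causality, only positions after (25, 24) should appear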

pclucas14 commented 2 years ago

Hi,

Thanks for raising this issue; let me get back to you.

pclucas14 commented 2 years ago

Ok, I think I know why: when sampling, the output comes from a stochastic distribution (or else you would always get the same image from the model). This distribution should indeed depend only on previous pixels, as you pointed out. You should try the same test with sample=False.
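
A minimal sketch of that retest, assuming the same model and data tensors as in the snippet above:

# rerun the perturbation probe without sampling
out = model(data_v, sample=False)
out1 = model(data_v1, sample=False)
print(out[0,10:13,20,20])
print(out1[0,10:13,20,20])  # should match the line above if causality holds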

Weafre commented 2 years ago

Hi Lucas, thanks for answering.

Actually, I printed out the distribution parameters (weights, means, scales); this is before the sampling step.

The output parameters at any spatial location should stay equal regardless of any changes to future pixels, since we are predicting the distribution p(x_i | x_{i-1}, x_{i-2}, …, x_1) in the form of a mixture of logistic distributions. Right?

Do you think the downsampling and upsampling cause the problem? In PixelCNN (not PixelCNN++) the spatial dimensions are always kept the same as the input.
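
For reference, a sketch of how the pre-sampling parameters can be compared directly; this assumes the channel layout from this repo's utils.py (for RGB with nr_mix = 10, the 100 output channels are nr_mix mixture logits followed by per-subpixel means, log-scales and coefficients):

nr_mix = 10
l = out.permute(0, 2, 3, 1)            # (B, H, W, 100)
logit_probs = l[:, :, :, :nr_mix]      # mixture weights (as logits)
rest = l[:, :, :, nr_mix:].reshape(l.size(0), l.size(1), l.size(2), 3, nr_mix * 3)
means      = rest[..., :nr_mix]
log_scales = rest[..., nr_mix:2 * nr_mix]
coeffs     = rest[..., 2 * nr_mix:]
# causality says all of these at (20, 20) must be unaffected by edits at (25, 24)
print(logit_probs[0, 20, 20], means[0, 20, 20, :, 0])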

pclucas14 commented 2 years ago

It’s hard to tell; typically when there is a leak, your loss goes down to something very small, which is not the case here. I would guess that there is no leak, since my numbers match the original implementation, but I can’t quite figure out why the numbers are different.

pclucas14 commented 2 years ago

Can you retry the previous experiment, but instead of setting a future pixel to 1000, set it to nan or inf?

Weafre commented 2 years ago

By modifying the future pixel to 1000 while the original values are in the range [-1, 1], there is only a slight change in the output. I guess there is a leak, but a tiny one; and in the end we compute the loss after sampling from the distribution, which makes the difference even smaller. If this is a problem, it is not a problem with your implementation; I guess the original paper has the same problem. Unfortunately, I tried to perform the same experiment using the OpenAI TensorFlow implementation but had issues with the env configuration 🥲 Why do you want to try with nan or inf?

pclucas14 commented 2 years ago

I ask because, honestly, the change is so small that I'm wondering if it could simply be due to numerical precision. With nan or inf, if there is a leak it will propagate, because nan * 2 = nan.
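
The point of the NaN probe, in a minimal self-contained example: NaN propagates through any arithmetic that reads it, so even a numerically tiny leak becomes unmissable:

import torch

x = torch.tensor([1.0, 2.0, float('nan')])
print(x[:2].sum())  # tensor(3.) — computations that never read the NaN stay finite
print(x.sum())      # tensor(nan) — anything that reads it becomes NaN
print(x * 2)        # tensor([2., 4., nan]) — and it keeps propagating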

Weafre commented 2 years ago
[Screenshots from 2022-02-28: model outputs after setting the future pixel to nan and to inf]

Hello, this is the output if we change the future pixel to nan or inf; and yes, the output at the current pixel is NaN in both cases.

Weafre commented 2 years ago

Is there any possible explanation for this other than a leak?

pclucas14 commented 2 years ago

I'm not sure, there could very well be. I tried your experiment with an old checkpoint (and a newer python version) but got a different result:

model.load_state_dict({k[len('module')+1:]: v for k,v in torch.load('models/pcnn_lr:0.00040_nr-resnet5_nr-filters160_499.pth').items()})
model.train(False)  # trained on cifar dataset
data = torch.zeros(1, obs[0], obs[1], obs[2])
data = data.cuda()

data_v = Variable(data, volatile=True)
out = model(data_v, sample=True)

#check output at spatial position [20,20]
print([x.item() for x in out[0,10:13,20,20]])

data1 = data.clone()  # clone, so the original input is not modified in place
#change rgb value at a future pixel in the input
data1[0,:,25,24] = 1000
data_v1 = Variable(data1, volatile=True)
out1 = model(data_v1, sample=True)

#and check the output again
print([x.item() for x in out1[0,10:13,20,20]])

which gave me

[0.00214940682053566, -0.00433339923620224, 0.0053053549490869045]
[0.00214940682053566, -0.00433339923620224, 0.0053053549490869045]

I'm running python 3.9 and torch 1.9, so that may be something to consider if you plan on using this code.

Weafre commented 2 years ago

Hi, I am running the same env configuration. Have you tried with nan and inf inputs?

pclucas14 commented 2 years ago

Yes, I get the same answer even with data1[0,:,25,24] = np.nan

Digging a bit more, I get the expected behavior, which is that the output at (25,24) is ok but (25,25) is nan

[Screenshot: NaN first appears in the output at (25,25), the pixel right after the modified one]
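
One way to make this check exhaustive (a sketch, reusing the model and data tensors from the snippets above): poison a single input pixel with NaN and list every output position that turns NaN; under raster-scan causality, only positions strictly after the poisoned one may appear:

# map which output positions depend on input pixel (25, 24)
poisoned = data.clone()
poisoned[0, :, 25, 24] = float('nan')
out_p = model(Variable(poisoned, volatile=True), sample=True)

nan_mask = torch.isnan(out_p).any(dim=1)[0]  # (H, W): True where any output channel is NaN
print(nan_mask.nonzero())  # every listed (row, col) should come strictly after (25, 24)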

Weafre commented 2 years ago

Mhm, it seems like the implementation is fine. Please keep the issue open; I'll check on my side again shortly. Is your checkpoint published somewhere?

From a theoretical perspective, I still think that down- and upsampling destroy the causality, as you can see in the sketch. Am I right?

[Sketch: hand-drawn diagram of how down/upsampling could mix future pixels into the receptive field]

pclucas14 commented 2 years ago

Ok, I see where your concern is coming from. They address this by padding the original input (everything is shifted so that the causality constraint stays respected); see https://github.com/pclucas14/pixel-cnn-pp/blob/master/layers.py#L38-L41
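
The trick, in a minimal sketch (not the repo's exact code; see layers.py at the link above for the real version): pad the feature map from above and crop from below, so a vertical-stack convolution at row i only sees the current and earlier rows, and a subsequent one-row down-shift makes the dependence strictly causal:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DownShiftedConv2d(nn.Module):
    # pad order is (left, right, top, bottom): all vertical context comes from
    # above, so output row i depends on input rows i-(kh-1) .. i
    def __init__(self, in_ch, out_ch, kernel=(2, 3)):
        super().__init__()
        kh, kw = kernel
        self.pad = (kw // 2, kw // 2, kh - 1, 0)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel)

    def forward(self, x):
        return self.conv(F.pad(x, self.pad))

def down_shift(x):
    # drop the last row and pad a zero row on top: row i now holds features
    # computed from rows <= i - 1, which restores strict causality
    return F.pad(x[:, :, :-1, :], (0, 0, 1, 0))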

Weafre commented 2 years ago

Is padding-and-shifting an alternative way of enforcing causality, instead of the masking used in PixelCNN? Could you explain why we can just downsample and upsample without worrying about the causality? Thanks.

pclucas14 commented 2 years ago

Yes to your first question. At the end of the day, iirc, the upsampling and downsampling are implemented as convolutions, so you can apply the same logic.
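
To see why striding does not change the argument, the same NaN probe can be run on a single shifted convolution with stride 2 (a sketch reusing the imports from above; the conv weights are random, which is enough for the propagation test):

conv = nn.Conv2d(1, 1, (2, 3), stride=2)
x = torch.zeros(1, 1, 8, 8)
x[0, 0, 5, 3] = float('nan')       # poison a pixel in row 5
y = conv(F.pad(x, (1, 1, 1, 0)))   # same pad-above / crop-below scheme, now strided
print(torch.isnan(y)[0, 0])        # NaNs only in output rows whose receptive field reaches row 5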

Weafre commented 2 years ago

Hi, I have tested the causality many times but am still getting different output values.

Could you please share your checkpoint, or test with an untrained/early checkpoint?

I initialized the model and trained for just 1 epoch before testing the causality; that's the only difference between your experiment and mine.

pclucas14 commented 2 years ago

Hi,

I tried using a fresh version of the repo (to make sure I didn't have any changes of my own), and running this extracted script still gives me the expected output:

import time
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from utils import *
from model import *
from PIL import Image
import numpy as np
from torch.autograd import Variable  # used below; may also be re-exported by utils

parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-i', '--data_dir', type=str,
                    default='../pixel-cnn-pp/data', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='models',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be either cifar|mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Restore training from previous model checkpoint?')
# model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=0.999995,
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=64,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')
parser.add_argument('-D', '--debug', action='store_true', help='debug mode: does not save the model')
args = parser.parse_args()

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

obs = (3, 32, 32)

model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=3, nr_logistic_mix=args.nr_logistic_mix)

model = model.cuda()

model.load_state_dict({k[len('module')+1:]: v for k,v in torch.load('../pixel-cnn-pp/models/pcnn_lr:0.00040_nr-resnet5_nr-filters160_499.pth').items()})

model.train(False)  # trained on cifar dataset
data = torch.zeros(1, obs[0], obs[1], obs[2])
data = data.cuda()

data_v = Variable(data, volatile=True)
out = model(data_v, sample=True)

#check output at spatial position [20,20]
print([x.item() for x in out[0,10:13,20,20]]) 

data1 = data.clone()  # clone, so the original input is not modified in place
#change rgb value at a future pixel in the input
data1[0,:,25,24] = np.nan
data_v1 = Variable(data1, volatile=True)
out1 = model(data_v1, sample=True)

#and check the output again
print([x.item() for x in out1[0,10:13,20,20]])

Using the model checkpoint, the output is

[0.002149447798728943, -0.004333361983299255, 0.005305382888764143]
[0.002149447798728943, -0.004333361983299255, 0.005305382888764143]

And using a freshly initialized model (no checkpoint):

[0.12006216496229172, 0.1424560248851776, -0.09435388445854187]
[0.12006216496229172, 0.1424560248851776, -0.09435388445854187]

pclucas14 commented 2 years ago

Again, this was done with torch 1.9.0 and python 3.9.6. I'm not really sure where the error is coming from on your end; I would suggest doing as I did: pull fresh from source with the package versions above and see if you still get the same error.

Weafre commented 2 years ago

It's working now; upgrading torch to a newer version (1.10) solved the problem (I was using torch 1.8.1 before). Thanks!