rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

[Question] What is PixelSnail? How to Train it? #4

Open EibrielInv opened 5 years ago

EibrielInv commented 5 years ago

Hi! I'm failing to understand the function of PixelSnail. Is it to generate a latent space similar to a GAN?

I trained the VQ-VAE successfully (until the samples were good enough):

python train_vqvae.py ./dataset_path

Then I performed a test to train PixelSnail. Is it correct?

Extracted the codes (I assume these are the encodings for each image in the dataset):

python extract_code.py --ckpt checkpoint/vqvae_241.pt --name small256 ./dataset_path

Then I trained the top hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier top --batch 8 small256

Then I trained the bottom hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier bottom --batch 8 small256

And finally I sampled:

python sample.py --vqvae checkpoint/vqvae_001.pt --top checkpoint/pixelsnail_top_001.pt --bottom checkpoint/pixelsnail_bottom_001.pt output.png

The output, as expected, is just noise, since I only trained PixelSNAIL for 1 batch.

[image: output]

If I just keep training PixelSnail will I be able to obtain good samples?

Hardware: NVIDIA 1080Ti

Thank you!

rosinality commented 5 years ago

Yes, it generates samples of latent codes for the VQ-VAE. I checked that it can produce some samples if you train it long enough, but you will need to use quite a large model.

pclucas14 commented 5 years ago

Would you mind sharing samples? Just to get an idea of what to expect.

rosinality commented 5 years ago

[image: sample]

Not very nice, but it is from a somewhat smaller model than the model in the paper.

pclucas14 commented 5 years ago

that's pretty good! thanks for sharing :)

1Konny commented 5 years ago

looks great! would you mind sharing your (hyper-)parameter setting and the resultant accuracy of top/bottom PixelSNAIL for this result?

rosinality commented 5 years ago

Top

Trained 109 epochs with lr 1e-4 and 9 epochs with lr 1e-5, and accuracy was about 48%

Bottom

Trained 70 epochs with lr 1e-4 and 4 epochs with lr 1e-5, and accuracy was about 21%

I think you can increase res_channel & dropout to match the hyperparameters in the paper, but I can't use that setting because of the large memory requirements.
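To put the accuracy figures above in context, here is a minimal sketch. It assumes "accuracy" means the fraction of code positions where the highest-scoring code equals the true code, and that the codebook has 512 entries; both are assumptions rather than something stated in this thread. Under those assumptions, random guessing would score about 0.2%. The helper `top1_accuracy` and the toy data are hypothetical.

```python
def top1_accuracy(logits, targets):
    """Fraction of positions whose highest-scoring code matches the target.

    logits: list of per-position score lists; targets: list of code indices.
    """
    correct = 0
    for scores, target in zip(logits, targets):
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


n_codes = 512          # assumed codebook size
chance = 1 / n_codes   # accuracy of uniform random guessing

# Toy example with 3 codes and 4 positions (made-up numbers).
example_logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
example_targets = [1, 0, 1, 0]
accuracy = top1_accuracy(example_logits, example_targets)
```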

Hope this helps.

1Konny commented 5 years ago

it really helps! thanks for the details.

k-eak commented 5 years ago

> Top
>   • channel: 512
>   • n_block: 4
>   • n_res_block: 5
>   • res_channel: 512
>   • n_cond_res_block: 0
>   • n_out_res_block: 5
>   • attention: True
>   • dropout: 0.1
>   • batch size: 63 (1e-4) / 64 (1e-5)
>
> Trained 109 epochs with lr 1e-4 and 9 epochs with lr 1e-5, and accuracy was about 48%
>
> Bottom
>   • channel: 512
>   • n_block: 4
>   • n_res_block: 5
>   • res_channel: 512
>   • n_cond_res_block: 5
>   • cond_res_channel: 512
>   • n_out_res_block: 0
>   • attention: False
>   • dropout: 0.1
>   • batch size: 64
>
> Trained 70 epochs with lr 1e-4 and 4 epochs with lr 1e-5, and accuracy was about 21%
>
> I think you can increase res_channel & dropout to match the hyperparameters in the paper, but I can't use that setting because of the large memory requirements.
>
> Hope this helps.

Thank you for sharing the details. Could you also share how many GPUs you used to train these networks?

rosinality commented 5 years ago

@k-eak I have used 4 V100s with mixed precision training.

ywang370 commented 5 years ago

> @k-eak I have used 4 V100s with mixed precision training.

@rosinality Thanks for sharing the details. I am curious about the mixed precision training. Did you use a package for it? Does mixed precision training help you increase the batch size? Would it be possible to share more on this part?

rosinality commented 5 years ago

@ywang370 I have used NVIDIA apex amp (https://github.com/NVIDIA/apex) with opt_level O1. I think mixed precision training was quite helpful for increasing batch sizes and reducing training times. It is hard to compare directly because the GPUs are different, but mixed precision training on V100s is more than 2x faster, with 2x the batch size, than FP32 training on P40s.

phongnhhn92 commented 5 years ago

@rosinality Hi, I have 2 P100s. Is there any improvement in general if I use apex for training PixelSNAIL? Would you mind sharing the code you used to enable the apex mixed precision training you mentioned above? I cannot find where you used apex in the GitHub repo.

rosinality commented 5 years ago

@phongnhhn92 Added simple support for apex in 7a2fbda.

zaitoun90 commented 5 years ago

Hi, I have trained the VQ-VAE and I got very similar images. My dataset has 159 images. Then I ran extract_code.py. My question here: how many checkpoints should I use in the end?

After that I tried to run train_pixelsnail.py (every time I get an error at line 40 of dataset.py, something about no decode). Then I tried to check whether the lmdb file has any data in it: I printed env.stat() and got this output ({'psize': 4096, 'depth': 0, 'branch_pages': 0, 'leaf_pages': 0, 'overflow_pages': 0, 'entries':)

I have been trying to solve it, but it is not working.

thanks a lot.

karamarieliu commented 5 years ago

How long did it take you per epoch (and how many iterations did you have in an epoch)? I'm finding it takes a considerable amount of time (~7 hours for 34k iterations of batch size 32).

zaitoun90 commented 5 years ago

> How long did it take you per epoch (and how many iterations did you have in an epoch)? I'm finding it takes a considerable amount of time (~7 hours for 34k iterations of batch size 32).

Which one do you mean (train_vqvae.py or extract_code.py)? Neither of them takes long, about 15 minutes. I have a small dataset and I am using 2x GTX 1080 GPUs.

For train_pixelsnail.py I have not succeeded so far; I still have the problem described above.

I used the original parameters and only changed batch_size to 32 (for train_vqvae.py).

rosinality commented 5 years ago

@zaitoun90 Could you recheck the extract_code.py step? I think it might be an lmdb-related problem.

@karamarieliu Yes, train_pixelsnail.py requires a lot of time, as the PixelSNAIL model is quite large.

zaitoun90 commented 5 years ago

@rosinality thanks, now it is working.

zaitoun90 commented 5 years ago

Hi, one more question: I ran everything correctly, but I am still getting samples similar to the originals. I thought that I could generate different images. Is that possible? @rosinality Are the samples that you shared different from the original dataset?

rosinality commented 5 years ago

@zaitoun90 Do you mean the output decoded from ground-truth codes? Then it should be similar to the input images. To get samples from the model, you can use sample.py.

zaitoun90 commented 5 years ago

@rosinality Yes, I use sample.py, but the output is still similar to the input images. After this long training of the VQ-VAE and PixelSNAIL, I expected to be able to generate different samples.

rosinality commented 5 years ago

@zaitoun90 sample.py doesn't use image inputs. sample.py should generate samples from scratch.
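As a rough illustration of what "from scratch" means, here is a minimal pure-Python sketch of raster-order autoregressive sampling over a grid of discrete codes, which is the general idea behind sampling from a PixelSNAIL-style prior. It is not the code in sample.py: `toy_logits` is a hypothetical stand-in for the trained network, and the grid and codebook sizes are made up. In the real pipeline the sampled top codes would condition the bottom prior, and the VQ-VAE decoder would turn both code maps into an image.

```python
import math
import random


def toy_logits(grid, i, j, n_codes):
    """Hypothetical stand-in for the prior network.

    Returns unnormalised scores for position (i, j) that depend only on
    codes already sampled (the left and upper neighbours).
    """
    left = grid[i][j - 1] if j > 0 else 0
    up = grid[i - 1][j] if i > 0 else 0
    # Favour codes close to the average of the two neighbours.
    centre = (left + up) / 2
    return [-abs(k - centre) for k in range(n_codes)]


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax, numerically stabilised."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_codes(height, width, n_codes, temperature=1.0, seed=0):
    """Sample a code map one position at a time, in raster order."""
    rng = random.Random(seed)
    grid = [[0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            probs = softmax(toy_logits(grid, i, j, n_codes), temperature)
            grid[i][j] = rng.choices(range(n_codes), weights=probs, k=1)[0]
    return grid


codes = sample_codes(height=4, width=4, n_codes=8)
```

Note that no image is read anywhere in this loop: every code is drawn from a distribution computed from previously drawn codes only.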

k-eak commented 5 years ago

@rosinality When I checked sample.py, I noticed that the F.one_hot function seems to take too much time (190 seconds for the top level with batch size 32). I tried replacing it with a scatter function that updates the one-hot tensor from the previous samples, but for some reason the network itself now runs much slower. Do you have any idea why this is happening, and any suggestions on how to improve the sampling time?

rosinality commented 5 years ago

@k-eak The current implementation is quite inefficient; for example, one_hot operates on sequences of 16896 elements per example at the top level. Maybe you could use some kind of caching. I have also tried to implement caching, but I only got about a 2x improvement...

k-eak commented 5 years ago

@rosinality Thank you for the suggestion. I replaced one_hot and now update the one-hot tensor after each sample with the scatter function. Although this is faster than one_hot, the network now takes longer to process, so in the end the improvement is very small. Do you think I am missing something?

Here is the changed sampling code: (I removed one_hot function in pixelsnail)

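import torch
from tqdm import tqdm

# Signature assumed to match sample_model in sample.py
@torch.no_grad()
def sample_model(model, device, batch, size, temperature, condition=None):
    # One-hot input buffer (512 codes), filled in as positions are sampled
    row = torch.zeros(batch, 512, *size, dtype=torch.int64, device=device)
    row_sample = torch.zeros(batch, *size, dtype=torch.int64, device=device)
    cache = {}
    for i in tqdm(range(size[0])):
        for j in range(size[1]):
            out, cache = model(row[:, :, : i + 1, :], condition=condition, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            sample = torch.multinomial(prob, 1)
            # Set the sampled code's channel to 1 at (i, j) instead of calling F.one_hot
            row[:, :, i, j] = row[:, :, i, j].scatter(1, sample, 1)
            row_sample[:, i, j] = sample.squeeze(-1)
    return row_sample

rosinality commented 5 years ago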

@k-eak Did you add torch.cuda.synchronize()? I think the speed measurement can be inaccurate because of the asynchronous nature of PyTorch. Also, the speed gain can be small, as much of the computation occurs in the rest of the model.
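A minimal sketch of the timing pattern, kept free of PyTorch so it runs anywhere: the hypothetical helper `time_call` takes an optional `sync` callable and invokes it before starting and before stopping the clock. On a CUDA device one would pass `torch.cuda.synchronize` as `sync`; without it, the timer may stop while GPU kernels are still queued.

```python
import time


def time_call(fn, sync=None, repeats=3):
    """Mean wall-clock seconds per call of fn().

    sync, if given, is called before starting and before stopping the timer
    so that asynchronously queued work is included in the measurement.
    """
    if sync is not None:
        sync()  # finish anything queued earlier so it is not billed to fn
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()  # wait for work launched by fn before reading the clock
    return (time.perf_counter() - start) / repeats


# CPU-only demonstration; on a CUDA device one would pass
# sync=torch.cuda.synchronize (after importing torch).
elapsed = time_call(lambda: sum(range(10_000)))
```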

k-eak commented 5 years ago

@rosinality Oh, my bad, I needed to add torch.cuda.synchronize(). So my method does not change the speed that much; it mostly saves a couple of seconds for large batches. I might try adding the caching idea from https://github.com/PrajitR/fast-pixel-cnn/blob/master/fast_pixel_cnn_pp/fast_nn.py but it might take some time to implement in PyTorch.

Mut1nyJD commented 5 years ago

When I use train_pixelsnail.py accuracy immediately hits 1.0 and the loss goes to basically zero after less than 100 iterations. This feels weird to me, what is going on?

I've got these settings:

amp='O0', batch=12, channel=512, ckpt=None, dropout=0.1, epoch=200, hier='top', lr=0.0001, n_cond_res_block=0, n_out_res_block=5, n_res_block=5, n_res_channel=512

Slimco86 commented 4 years ago

@Mut1nyJD

> When I use train_pixelsnail.py accuracy immediately hits 1.0 and the loss goes to basically zero after less than 100 iterations. This feels weird to me, what is going on?
>
> I've got these settings:
>
> amp='O0', batch=12, channel=512, ckpt=None, dropout=0.1, epoch=200, hier='top', lr=0.0001, n_cond_res_block=0, n_out_res_block=5, n_res_block=5, n_res_channel=512

I have the same issue. In my case the dataset might be "too simple", but this is just my guess. What about your data?

Mut1nyJD commented 4 years ago

@Slimco86

No, I don't think it is too simple. I am using this dataset:

https://www.mut1ny.com/peoplepose20k

Slimco86 commented 4 years ago

@Mut1nyJD OK, I figured it out. In my case the VQ-VAE training had converged to a local minimum, so the reconstructions were poor and almost identical to each other. I retrained it, playing around with the hyperparameters, and now everything is fine.
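One way to spot this kind of collapse before training PixelSNAIL is to look at how many distinct codes the extracted code maps actually use. The sketch below is a hypothetical helper, not part of the repo, and it assumes the code maps have already been loaded as nested lists of integer indices (reading them from the lmdb file is not shown). If almost all positions share one code, predicting the next code is trivial, which would be consistent with accuracy jumping to 1.0.

```python
import math
from collections import Counter


def code_usage(code_maps):
    """Summarise which codebook entries appear in a set of code maps.

    code_maps: iterable of 2-D lists of integer code indices.
    Returns (number of distinct codes used, perplexity of the usage
    distribution, share of the most frequent code).
    """
    counts = Counter(c for code_map in code_maps for row in code_map for c in row)
    total = sum(counts.values())
    probs = [n / total for n in counts.values()]
    entropy = -sum(p * math.log(p) for p in probs)
    return len(counts), math.exp(entropy), max(probs)


# Toy data: a collapsed encoder emits one code everywhere, a healthier one varies.
collapsed = [[[7, 7], [7, 7]], [[7, 7], [7, 7]]]
varied = [[[0, 1], [2, 3]], [[3, 2], [1, 0]]]

collapsed_stats = code_usage(collapsed)
varied_stats = code_usage(varied)
```

A perplexity close to 1, or a single code covering nearly all positions, suggests the VQ-VAE should be retrained before spending time on the prior.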

drtonyr commented 4 years ago

I've seen this problem. In my case I can reproduce it easily by setting the learning rate, --lr, high.

I also get the opposite problem, where the latent goes to zero and the model settles into producing a uniform colour output. This is another local minimum: just don't pass any information through and output a constant.

I'm guessing that it is weighting the latent cost too much. I don't know where latent_loss_weight = 0.25 comes from. I reduced the factor from 0.25 to 0.05 (complete guess) and it seemed to fix the problem.
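For what it's worth, 0.25 is the value of the commitment cost β reported in the original VQ-VAE paper, which is presumably where the default comes from. The sketch below shows how the weight enters the objective, assuming the total loss has the form reconstruction loss plus latent_loss_weight times the latent loss; the helper names and toy numbers are made up for illustration and are not taken from train_vqvae.py.

```python
def mse(a, b):
    """Mean squared error between two equal-length sequences of floats."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def vqvae_loss(image, recon, encoder_out, quantized, latent_loss_weight=0.25):
    """Reconstruction loss plus a weighted latent (commitment) term.

    The latent term penalises the distance between the encoder output and
    the codebook vector it was quantised to.
    """
    recon_loss = mse(recon, image)
    latent_loss = mse(encoder_out, quantized)
    return recon_loss + latent_loss_weight * latent_loss, recon_loss, latent_loss


# Made-up values: a small reconstruction error and a larger latent error.
image = [0.0, 1.0, 0.5, 0.25]
recon = [0.1, 0.9, 0.5, 0.25]
encoder_out = [1.0, -1.0]
quantized = [0.5, -0.5]

for weight in (0.25, 0.05):
    total, recon_loss, latent_loss = vqvae_loss(
        image, recon, encoder_out, quantized, latent_loss_weight=weight
    )
    print(f"weight={weight}: total={total:.4f}")
```

With a smaller weight the latent term contributes less to the total, so the gradient is dominated by reconstruction quality.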

Overall I don't find that VQ-VAE-2 fits with PixelSNAIL very cleanly: there are too many things to train independently and hope they fit together at the end. There has to be a single cost function that would be cleaner.

Mut1nyJD commented 4 years ago

@drtonyr
Okay, thanks. I was using the default --lr setting, but I'll try again with your suggestion as soon as I find some time in my training slot :). That means decreasing the latent loss weight and, if that does not help, the lr. But I agree that this combination of VQ-VAE-2 and PixelSNAIL feels suboptimal, even though VQ-VAE-2 on its own does provide good reconstructions.

LinfengLiu98 commented 3 years ago

> @zaitoun90 Could you recheck the extract_code.py step? I think it might be an lmdb-related problem. @karamarieliu Yes, train_pixelsnail.py requires a lot of time, as the PixelSNAIL model is quite large.

Hey, regarding the training time, do you count it in days or just a few hours? Mine takes really long, probably days. I'm using one 32 GB GPU (Tesla SXM2). Thank you!

easonoob commented 1 year ago

> @zaitoun90 sample.py doesn't use image inputs. sample.py should generate samples from scratch.

@rosinality Umm, then what happens if I input an image? Will it become a whole new image? But isn't this a VAE that reconstructs images? What is PixelSnail doing? Is it a generator? Thanks!