google-research / pix2seq

Pix2Seq codebase: multi-tasks with generative modeling (autoregressive and diffusion)
Apache License 2.0

Trouble training RIN on CIFAR-10 #42

Closed: leon-w closed this issue 1 year ago

leon-w commented 1 year ago

Hi,

I'm currently trying to train RIN on CIFAR-10 using the code and config that was provided in the repository.

I made some minor changes to the code to get it to work.

The entire diff can be seen here: https://github.com/google-research/pix2seq/compare/main..leon-w:b6609

This is the exact code I used for training and evaluation.

Training Setup

I trained the model using this command:

python run.py \
 --config configs/config_diffusion_cifar10.py \
 --mode train \
 --model_dir results/cifar10 \
 --config.train.checkpoint_epochs 5 \
 --config.train.keep_checkpoint_max 2

Training log: train_log.txt

Training finishes after around 2 hours. These are the training curves logged to TensorBoard:

[image: training curves]
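For readers unfamiliar with the `--config.train.checkpoint_epochs 5` style of flag used in the training command: each dotted flag addresses one field of the nested config returned by the config file and replaces its value. The snippet below is a minimal, stdlib-only sketch of that idea, assuming the config can be treated as a plain nested dict. It is NOT the pix2seq or ml_collections implementation, the helper name `apply_dotted_overrides` is hypothetical, and apart from the two overridden field names (taken from the command above) the base values are made up.

```python
import copy
from typing import Any


def apply_dotted_overrides(config: dict, overrides: dict[str, Any]) -> dict:
    """Return a copy of `config` with dotted-path overrides applied.

    Hypothetical helper: a rough stand-in for how flags such as
    `--config.train.checkpoint_epochs 5` address nested config fields.
    """
    result = copy.deepcopy(config)  # leave the caller's config untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node[name]  # KeyError if an intermediate section is missing
        if leaf not in node:
            # This sketch refuses to create fields that do not already exist.
            raise KeyError(f"unknown config field: {dotted_key}")
        node[leaf] = value
    return result


# Made-up base config; only the two overridden field names come from the command.
base = {"train": {"checkpoint_epochs": 1, "keep_checkpoint_max": 10, "epochs": 100}}
updated = apply_dotted_overrides(
    base, {"train.checkpoint_epochs": 5, "train.keep_checkpoint_max": 2}
)
print(updated["train"])
# {'checkpoint_epochs': 5, 'keep_checkpoint_max': 2, 'epochs': 100}
```

The deep copy keeps the base config intact, so the same defaults can be reused with different override sets.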

Eval Setup

I then ran the trained model in evaluation mode to generate a few samples:

python run.py \
 --config configs/config_diffusion_cifar10.py \
 --mode eval \
 --model_dir results/cifar10 \
 --config.eval.steps 1

Evaluation log: eval_log.txt

Unfortunately, the generated samples don't seem to contain anything meaningful; they look like pure noise:

[image: generated samples]

I also tried training the model 10x longer, but the samples were still pure noise.

Has anyone successfully trained a RIN model with this codebase, and does anyone have an idea how I can get this to work? Any help would be highly appreciated!

leon-w commented 1 year ago

I figured it out: instead of using the config returned by get_config() in config_diffusion_cifar10.py, you should use the hyperparameters contained in get_sweep() in the same file. With those, I got very convincing results. Here are some samples from the first 67k training steps:

[video: tape_cifar10_samples.webm]

If anyone is looking for ready-to-use code, feel free to use this: https://github.com/leon-w/pix2seq/tree/cifar10
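If you want to see exactly which hyperparameters get_sweep() changes relative to get_config(), one option is to flatten both configs into dotted keys and compare them; each differing key is a candidate `--config.<key> <value>` override for run.py, in the same style as the `--config.train.*` flags in the training command. This is a minimal sketch assuming both configs have already been turned into plain nested dicts (for an ml_collections `ConfigDict`, `to_dict()` does this) and that the sweep values have been materialized into such a dict; how get_sweep() is structured internally is not covered here. The helper names `flatten` and `changed_fields` are hypothetical, and the example configs are made-up placeholders, not the real CIFAR-10 values.

```python
from typing import Any, Iterator


def flatten(config: dict, prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (dotted_key, value) for every leaf of a nested dict."""
    for key, value in config.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into sub-sections such as "train" or "model".
            yield from flatten(value, prefix=path + ".")
        else:
            yield path, value


def changed_fields(base: dict, swept: dict) -> dict[str, tuple[Any, Any]]:
    """Return {dotted_key: (base_value, swept_value)} for every leaf that differs.

    Keys present only in `swept` are reported with base_value None.
    """
    base_flat = dict(flatten(base))
    return {
        key: (base_flat.get(key), value)
        for key, value in flatten(swept)
        if key not in base_flat or base_flat[key] != value
    }


# Made-up placeholder configs; these are NOT the actual pix2seq CIFAR-10 values.
base_cfg = {"train": {"steps": 1000, "batch_size": 128}, "model": {"dropout": 0.0}}
sweep_cfg = {"train": {"steps": 5000, "batch_size": 128}, "model": {"dropout": 0.1}}

diff = changed_fields(base_cfg, sweep_cfg)
for key, (old, new) in diff.items():
    # Each printed line is a candidate command-line override for run.py.
    print(f"--config.{key} {new}  # was {old}")
```

With these placeholder values it prints `--config.train.steps 5000  # was 1000` and `--config.model.dropout 0.1  # was 0.0`; unchanged fields such as `train.batch_size` are skipped.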