lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch
MIT License

notes #2

Closed · ajabri closed this 1 year ago

ajabri commented 1 year ago

Hi Phil,

Thanks for your interest in the architecture!

I just wanted to point out that we don't actually tie the weights across RIN blocks of the same forward pass: https://github.com/lucidrains/recurrent-interface-network-pytorch/blob/main/rin_pytorch/rin_pytorch.py#L398 This is important because tying the weights within the forward pass (which I experimented with) gives the model significantly less capacity. I will highlight this more in the paper's pseudo-code in an updated version.

In that way, the architecture is actually more similar to models like Perceiver / BigBird / memory transformers than to ISAB, because achieving depth without reinitializing the "inducing points" is quite important for the model to learn to use the latents effectively (ISAB mostly achieves clustering of the data tokens at each block, since it resets the inducing points at each block).

You can view RIN as factoring a global self-attention matrix that operates on the combined set of [global, data] tokens into global-global, global-data (write), data-global (read), and data-data interactions. We never use data-data interactions (like Perceiver), but we show that adding global-data interactions (writing) is important. We use global-global interactions (like BigBird / memory transformers), but apply them more frequently than the other interactions; importantly, the global tokens have more channels and are processed with wider blocks.

The 'recurrent' aspect actually comes from the fact that warm-starting the latents across iterations (i.e. forward passes) gives you an effect akin to an RNN, since you can propagate latent information; the result is that the model is effectively very deep at inference.
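Roughly, the structure looks like the following sketch (illustrative only, using nn.MultiheadAttention as a stand-in; the widths and names are assumptions, not the paper's or this repo's exact code):

```python
import torch
from torch import nn

class RINBlockSketch(nn.Module):
    def __init__(self, dim = 256, dim_latent = 512, latent_self_attn_depth = 2, heads = 8):
        super().__init__()
        # read: latents (queries) attend to the data / interface tokens (keys, values)
        self.read = nn.MultiheadAttention(dim_latent, heads, kdim = dim, vdim = dim, batch_first = True)
        # compute: self-attention over the few, wide latents, where most of the computation happens
        self.compute = nn.ModuleList([
            nn.MultiheadAttention(dim_latent, heads, batch_first = True)
            for _ in range(latent_self_attn_depth)
        ])
        # write: data tokens (queries) attend back to the latents
        self.write = nn.MultiheadAttention(dim, heads, kdim = dim_latent, vdim = dim_latent, batch_first = True)

    def forward(self, tokens, latents):
        latents = latents + self.read(latents, tokens, tokens, need_weights = False)[0]
        for attn in self.compute:
            latents = latents + attn(latents, latents, latents, need_weights = False)[0]
        tokens = tokens + self.write(tokens, latents, latents, need_weights = False)[0]
        return tokens, latents

# a stack of *separate* blocks: weights are untied across the forward pass
blocks = nn.ModuleList([RINBlockSketch() for _ in range(6)])

tokens, latents = torch.randn(1, 1024, 256), torch.randn(1, 128, 512)
for block in blocks:
    tokens, latents = block(tokens, latents)
# at the next denoising step, `latents` can be fed back in (latent warm-starting)
```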

Happy to discuss further :) Allan

thuanz123 commented 1 year ago

Hi @ajabri, what do you mean by tying the weights?

lucidrains commented 1 year ago

hey Allan, given the claims of the paper, could you try training on 1024 x 1024 FFHQ, especially if you still have access to google-level compute? the results would be most informative. I'm skeptical you can stick with 256 latents and still get good results, but willing to change my mind.

darius-lam commented 1 year ago

@lucidrains How long did you train the 130k flowers model for on how many GPUs? What were the hyperparameters you used, like model dimension and number of layers?

lucidrains commented 1 year ago

@Lamikins pretty much what is in the readme

it should be even better, as the run i did was using weight-tied layers

darius-lam commented 1 year ago

@ajabri I was unable to reproduce your CIFAR10 results from the paper.

Here are my FID scores over ~90k steps using RIN with dim=256, patch_size=1, num_latents=256, depth=6, latent_self_attn_depth=2 and batch_size=128. The learning rate was scaled from 1e-3 to ~2e-4 over training, and the EMA decay was set to 0.9999 after a few thousand steps.

[screenshot: FID curve over ~90k training steps]

I am wondering how you were able to achieve the 1.81 FID in the paper. Did you use any additional model changes or hyperparameters?

Example Generations:

[image: example generations at ~89.5k steps]

thuanz123 commented 1 year ago

@Lamikins, from what I understand, the latent (z) dim and attention should be larger than the image (x) dim and attention, yet they are the same in this implementation. Is that correct, @ajabri @lucidrains?

darius-lam commented 1 year ago

Yeah, that's correct, although the cross-attention requires the dims of Q, K, V to be the same, so eventually we have to map to the same dimensions anyway.

thuanz123 commented 1 year ago

Yeah, except for the x-z cross-attention and x self-attention; the self-attention for z can be much larger, so that the model can have greater learning capacity.
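Concretely, a "read" cross-attention with distinct latent and interface widths might look something like this (an illustrative sketch with assumed names, not this repo's code), where both sides are projected to a shared inner dimension only for the attention itself:

```python
import torch
from torch import nn

class ReadCrossAttention(nn.Module):
    def __init__(self, dim_latent = 512, dim_interface = 256, dim_head = 64, heads = 8):
        super().__init__()
        inner = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm_latents = nn.LayerNorm(dim_latent)
        self.norm_interface = nn.LayerNorm(dim_interface)
        self.to_q = nn.Linear(dim_latent, inner, bias = False)         # latents -> queries
        self.to_kv = nn.Linear(dim_interface, inner * 2, bias = False) # interface -> keys, values
        self.to_out = nn.Linear(inner, dim_latent, bias = False)       # back to the latent width

    def forward(self, latents, interface):
        b, n, h = latents.shape[0], latents.shape[1], self.heads
        q = self.to_q(self.norm_latents(latents))
        k, v = self.to_kv(self.norm_interface(interface)).chunk(2, dim = -1)
        # split heads: (batch, tokens, inner) -> (batch, heads, tokens, dim_head)
        q, k, v = (t.reshape(b, -1, h, t.shape[-1] // h).transpose(1, 2) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return latents + self.to_out(out)  # residual back into the (wider) latent stream

# latents can be wider (dim_latent) than the interface tokens (dim_interface)
read = ReadCrossAttention()
latents, interface = torch.randn(1, 128, 512), torch.randn(1, 1024, 256)
latents = read(latents, interface)
```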

darius-lam commented 1 year ago

oh interesting, I overlooked that. I'll try running those experiments.

lucidrains commented 1 year ago

@Lamikins i apologize but i had a big bug where i tied the cross attention layers going to and from image to latents - should be fixed in 0.1.0! i'll rerun my experiments overnight

@thuanz123 thanks Thuan for noticing that the latent dimension can be separate from the image dimension. i've added the option as dim_latent

darius-lam commented 1 year ago

Thanks Phil! I caught that as well. Unfortunately, despite my best efforts this weekend, I was unable to get RIN to converge. I used this architecture in the k-diffusion repo and the model is very unstable during training.

I added layer norms to the architecture similar to the normal transformer and that seemed to help a bit.

[screenshot: FID curves over training]

However, despite adding layer norm, untying the weights, and enabling different latent_dim and interface_dim, the FID scores are nowhere near the paper's (nor near the k-diffusion base model, which gets a CIFAR FID of ~28). The increasing FID scores are due to NaNs in the base model (I use the EMA model to calculate FID).

This is using a batch size of 128, lr=1e-3 with Adam, and model sizes varying between 30 and 75M parameters.

darius-lam commented 1 year ago

One thing to note: without a layernorm on every latent output, the magnitude of the latents grows very large, which is one cause of the NaNs.

[screenshot: latent magnitude over training, with and without layer norm]

Red has no layer norm, the others do. Obviously this effect is worse as depth increases.

lucidrains commented 1 year ago

@Lamikins oh that's strange, i have pre-layernorms for all the attention and feedforward layers

lucidrains commented 1 year ago

@Lamikins try lowering your learning rate to something like 2e-4

lucidrains commented 1 year ago

@Lamikins where are you adding the layernorms?

lucidrains commented 1 year ago

@Lamikins actually yea, that's another thing with the ISAB-like architecture. i've run into stability issues trying to apply it for contract work where i needed long context, so i could believe what you are telling me

darius-lam commented 1 year ago

yea I think the learning rate may be too high, although I don't have the compute to do 12-hour-long runs with a lower rate šŸ˜… Let me know if you want to collaborate on reproduction.

In this block:

```python
# latent self attention

for attn, ff, ln1, ln2 in self.latent_self_attns:
    # pre-norm before the latent self-attention, with a residual
    latents = attn(ln1(latents), time=t) + latents
    # post-norm applied to the feedforward residual, which keeps the latent magnitudes bounded
    latents = ln2(ff(latents, time=t) + latents)
```

I found that ln2(ff(latents, time=t) + latents), versus the more standard ff(ln2(latents), time=t) + latents, makes it harder for the latent vector to explode. After your comment I realize ln1 is redundant, so I'll try removing it.

lucidrains commented 1 year ago

@Lamikins ohh got it, yes, SwinV2 actually uses that pattern, but they place the ln2 every 3rd transformer block

lucidrains commented 1 year ago

@Lamikins actually yea, it may make sense to have a norm just before cross attending from the image features to the latents, let me add that

lucidrains commented 1 year ago

@ajabri hi Allan, just saw this paper https://arxiv.org/abs/2301.10972

Sorry for doubting your results. Will be putting more work into RIN soon šŸ™

juvekaradheesh commented 1 year ago

@lucidrains hi Phil, can I ask what the reason is for adding input self-conditioning? Did you find any performance increase from it? In the paper I only see latent self-conditioning being used.

lucidrains commented 1 year ago

@juvekaradheesh i can make that optional, was just improvising

juvekaradheesh commented 1 year ago

No, that's alright. I was just curious whether it improved performance, to decide if I should keep it or not. But anyway, it might be dataset-specific.

darius-lam commented 1 year ago

@lucidrains @ajabri I was able to reproduce the CIFAR10 results using eps pred instead of x_0 pred in this repo. One challenge: because we use latent warm-starting, it's hard to get good results with fewer sampling steps. Sampling seems to require the latents to be "close" to the learned distribution, i.e. the latents at t-1 must be very similar to those at t.

Allan does this seem plausible? Any ideas on how to mitigate the effect?

ajabri commented 1 year ago

Hi @Lamikins, for shorter sampling I found that DDIM works better. The way the latent self-conditioning is trained at the moment is a bit naive... it would be interesting to try unrolling the network, training it with different gaps between the latent self-conditioning steps, or using progressive distillation (perhaps with BPTT). And yeah, eps pred works significantly better; it's also good to use the LAMB optimizer with weight decay on kernels and positional embeddings.

lucidrains commented 1 year ago

oh oops, i don't remember why i had predict x0 objective as default

will get the predict eps objective in there tomorrow morning! (as well as the resolution dependent noise schedules from Ting Chen)
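for reference, the difference is just which target the denoiser regresses to; a rough sketch (placeholder names, not this repo's exact code):

```python
import torch
import torch.nn.functional as F

# rough sketch of the two objectives under a standard DDPM-style forward process;
# `denoiser` and `alpha_bar` are placeholders, not this repo's API
def diffusion_loss(denoiser, x0, alpha_bar, objective = 'eps'):
    b = x0.shape[0]
    t = torch.randint(0, len(alpha_bar), (b,), device = x0.device)
    a = alpha_bar[t].view(b, 1, 1, 1)
    eps = torch.randn_like(x0)
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps     # q(x_t | x_0)
    pred = denoiser(x_t, t)
    target = eps if objective == 'eps' else x0     # predict-eps vs predict-x0
    return F.mse_loss(pred, target)
```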

darius-lam commented 1 year ago

Great ideas @ajabri, I'm investigating the BPTT unrolling idea. I also think the latent shifting idea is good, but I haven't been able to get it to work; will try again.

Phil, here are the CIFAR10 generations I got after 70k steps at bs=512, AdamW 1e-3, eps_pred, DDPM 1000 steps. It took longer to train than I expected, around 12 hours on 1x A100, and I think I can achieve better results with another 12 hours. However, I calculated that it should only have taken ~6-12 hours total, assuming 1x A100 = 2x TPUv3 (3 hours on 8 chips). Let me know what results you are able to get.

[image: CIFAR10 sample grid]

Darius

ajabri commented 1 year ago

@Lamikins The LAMB optimizer with lr=1e-3 and weight decay on pos embeddings + kernels should work better. Sampling with 250 steps of DDIM should be decent. You should train for at least 150k steps (with bs=1024, so longer with bs=512). What kind of arch did you use?

lucidrains commented 1 year ago

(replying to @Lamikins's CIFAR10 results above)

yup it looks good

early generations look unfamiliar and checkerboardy given the patch training, but it converges nicely

i will try Ting Chen's logsnr shifting idea soon

lucidrains commented 1 year ago

(replying to @ajabri's LAMB suggestion above)

hmm, i think LAMB is strictly a google thing, all the papers i've read suggest it doesn't do that much

but i'll stay open minded and look around for an open sourced implementation (or code up my own)

ajabri commented 1 year ago

@lucidrains In case it might save you time, note that the logsnr shifting isn't necessary for images at 256 and below (it only really helps significantly when going to 512 and 1024, and mainly for faster convergence). Also, you'll find that AdamW converges to higher losses and worse sample metrics across the board; in my experience, LAMB is very useful when learning latent cross-attention (see also its use in Perceiver / PerceiverIO).

ajabri commented 1 year ago

@lucidrains yes ofc bigger is better :) (and that's the main motivation for RIN) ... but training a 512/1024 model is ofc more expensive (e.g. the 1024px model was trained on 256 TPUv4 chips with bsz=512 for 1M iterations), so it may not be ideal for tuning.

darius-lam commented 1 year ago

@lucidrains actually some friends put together this open-sourced LAMB: https://github.com/cybertronai/pytorch-lamb

It's 3 years old and slow so it might be worth forking + updating
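Basic usage would look roughly like this, with weight decay applied only to the weight matrices and positional embeddings as Allan suggested (the name-matching heuristic and param groups here are my assumptions, not from the linked repo's docs):

```python
from pytorch_lamb import Lamb  # from the repo linked above

# `model` is assumed to be your RIN instance; the decay / no-decay split below is a
# common heuristic, not something from the paper
decay, no_decay = [], []
for name, p in model.named_parameters():
    if p.ndim >= 2 or 'pos_emb' in name:   # weight matrices / kernels / positional embeddings
        decay.append(p)
    else:                                  # biases, norm gains, etc.
        no_decay.append(p)

optimizer = Lamb(
    [{'params': decay, 'weight_decay': 0.01},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr = 1e-3,
    betas = (0.9, 0.999),
)
```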

lucidrains commented 1 year ago

ok, i've incorporated Ting Chen's findings in the latest version, so 512 / 1024 training with RIN should work well by setting the scale to less than 1
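the gist of the input scaling, as i understand it (an illustrative sketch, not exactly how it's implemented here):

```python
import torch

# scaling x0 by a factor < 1 before noising lowers the signal-to-noise ratio at every
# timestep, which matters more at 512 / 1024 where images carry more redundant signal
# (Ting Chen's paper also discusses keeping the variance of x_t fixed after scaling)
def q_sample(x0, alpha_bar_t, scale = 0.7):
    noise = torch.randn_like(x0)
    x_t = alpha_bar_t.sqrt() * (x0 * scale) + (1. - alpha_bar_t).sqrt() * noise
    return x_t, noise
```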

ajabri commented 1 year ago

@Lamikins sorry I missed your message above. Here are some useful hyperparameters for CIFAR-10: optimizer: LAMB with lr=3e-3, cosine decay, weight decay of 0.01, dropout 0.1. Architecture: 3 blocks of depth 2, 128 latents, 512 latent dim, 256 interface dim.
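In this repo's terms, that might map onto something like the following (the mapping of "3 blocks of depth 2" onto depth / latent_self_attn_depth, as well as image_size / patch_size, are guesses rather than a verified config):

```python
from rin_pytorch import RIN

# rough mapping of the above onto this repo's kwargs; treat it as a starting point
model = RIN(
    dim = 256,                    # interface dim
    dim_latent = 512,             # latent dim
    num_latents = 128,
    depth = 3,                    # 3 RIN blocks...
    latent_self_attn_depth = 2,   # ...each with 2 latent self-attention layers
    image_size = 32,              # CIFAR-10
    patch_size = 1,
)
```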

lucidrains commented 1 year ago

closing this issue, as mostly resolved except for the tangent on optimizers

qsh-zh commented 1 year ago

(quoting @Lamikins's CIFAR10 results above)

@Lamikins What FID can you get on CIFAR10?

nicolas-dufour commented 1 year ago

Hey @ajabri. I've tried to reproduce your results on CIFAR but I could not achieve the same training speed that you report in the paper (3h on 8 TPUv3). I'm running on 8 V100s with a global batch size of 256 for 150k iterations, using the config you shared with 3 blocks of 2 modules. My training speed is around 4.5 it/s, so about 9h of training. I'm curious whether you had any tricks to speed things up, and what training speed you achieved? Thanks!

ajabri commented 1 year ago

Hi Nicolas,

We've finally been able to open source the code:

Architecture: https://github.com/google-research/pix2seq/blob/main/architectures/tape.py
Image diffusion: https://github.com/google-research/pix2seq/blob/main/models/image_diffusion_model.py
CIFAR-10 config: https://github.com/google-research/pix2seq/blob/main/configs/config_diffusion_cifar10.py

Hope this helps!

Allan