Hi @ajabri, what do you mean by tying the weights ?
hey Allan, given the claims of the paper, could you try training on 1024 x 1024 FFHQ, especially if you still have access to google-level compute? the results would be most informative. I'm skeptical you can stick with 256 latents and still get good results, but willing to change my mind.
@lucidrains How long did you train the 130k flowers model for on how many GPUs? What were the hyperparameters you used, like model dimension and number of layers?
@Lamikins pretty much what is in the readme
it should be even better, as the run i did was using weight tied layers
@ajabri I was unable to reproduce your CIFAR10 results from the paper,
here are my FID scores over ~90k steps using RIN with dim=256, patch_size=1, num_latents=256, depth=6, latent_self_attn_depth=2 and batch_size=128. Learning Rate scaled from 1e-3 to ~2e-4 over training. EMA set to .9999 after a few thousand steps.
I am wondering how you were able to achieve the 1.81 FID in the paper-- did you use any additional model changes, or additional hyperparameters?
Example Generations:
@Lamikins, from what I understand the latent (z) dim and attention should be larger than image (x) dim and attention, yet they are the same in this implementation. Is that correct @ajabri, @lucidrains ?
Yea that's correct; Although the cross attention requires the dims of Q,K,V to be the same, so eventually we will have to map to the same dimensions anyway
Yeah, except for the x-z cross attention and x self-attention, the self-attention for z can be much wider so that the model can have greater learning capacity
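To illustrate the dimension point with torch's built-in MultiheadAttention (not this repo's attention class; shapes are just an example): the q/k/v projections let the latent and image token widths differ, so only the inner attention dimension has to match.

```python
import torch
from torch import nn

dim, dim_latent = 256, 512  # image (interface) token dim vs a wider latent dim

# queries come from the latents, keys/values from the image tokens;
# the internal projections map both onto a shared inner dimension
cross_attn = nn.MultiheadAttention(dim_latent, num_heads = 8, kdim = dim, vdim = dim, batch_first = True)

latents = torch.randn(1, 256, dim_latent)     # z
tokens  = torch.randn(1, 1024, dim)           # x (patch tokens)
out, _  = cross_attn(latents, tokens, tokens) # -> (1, 256, dim_latent)
```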
oh interesting, I overlooked that. I'll try running those experiments.
@Lamikins i apologize but i had a big bug where i tied the cross attention layers going to and from image to latents - should be fixed in 0.1.0! i'll rerun my experiments overnight
@thuanz123 thanks Thuan for noticing that the latent dimension can be separate from the image dimension. i've added the option as dim_latent
Thanks Phil! I caught that as well, unfortunately despite my best efforts this weekend I was unable to get RIN to converge. I used this architecture in the k-diffusion repo and the model is very unstable during training.
I added layer norms to the architecture similar to the normal transformer and that seemed to help a bit.
However, despite adding layer norm, untying the weights, and enabling different latent_dim and interface_dim, the FID scores are nowhere near the paper's (nor near k-diffusion's base model, which gets a CIFAR FID of ~28). The increasing FID scores are due to NaNs in the base model (I use the EMA model to calculate FID).
This is using a batch size of 128, lr=1e-3 with ADAM, and varying between 30 and 75M parameters.
One thing to note: without using layernorm on every latent output, the magnitude of the latents grows very large, which is a reason for the NaNs.
Red has no layer norm, the others do. Obviously this effect is worse as depth increases.
@Lamikins oh that's strange, i have pre-layernorms for all the attention and feedforward layers
@Lamikins try lowering your learning rate to something like 2e-4
@Lamikins where are you adding the layernorms?
@Lamikins actually yea, that's another thing with the ISAB-like architecture. i've run into stability issues trying to apply it for some contract work where i needed long context, so i could believe what you are telling me
yea I think learning rate may be too high, although I don't have the compute to do 12-hour-long runs with a lower rate. Let me know if you want to collaborate on reproduction.
In this block:
# latent self attention
for attn, ff, ln1, ln2 in self.latent_self_attns:
    latents = attn(ln1(latents), time=t) + latents
    latents = ln2(ff(latents, time=t) + latents)
I found that ln2(ff(latents, time=t) + latents), versus the more standard ff(ln2(latents), time=t) + latents, makes it harder for the latent vectors to explode (see the sketch below).
Upon your comment I realize ln1 is redundant, so I'll try removing it
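A minimal standalone sketch of the two placements (illustration only, not the repo's modules):

```python
import torch
from torch import nn

dim = 512
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
ln = nn.LayerNorm(dim)

def pre_norm_step(latents):
    # standard pre-norm residual: the residual stream itself is never normalized,
    # so its magnitude can keep growing with depth
    return ff(ln(latents)) + latents

def post_norm_step(latents):
    # the placement described above: normalizing the summed output
    # bounds the residual stream and made the latents far less prone to exploding
    return ln(ff(latents) + latents)
```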
@Lamikins ohh got it, yes, SwinV2 actually uses that pattern, but they place the ln2 every 3rd transformer block
@Lamikins actually yea, it may make sense to have a norm just before cross attending from the image features to the latents, let me add that
@ajabri hi Allan, just saw this paper https://arxiv.org/abs/2301.10972
Sorry for doubting your results. Will be putting more work into RIN soon
@lucidrains hi Phil, can I ask what the reason is for adding input self-conditioning? Did you find any performance increase from it? In the paper I only see latent self-conditioning being used.
@juvekaradheesh i can make that optional, was just improvising
No that's alright, I was just curious if it improved performance, to decide if I should keep it or not. But anyway, it might be dataset specific.
@lucidrains @ajabri I was able to reproduce the CIFAR-10 results using eps pred instead of x_0 pred in this repo. One challenge: because we use latent warm-starting, it's hard to get good results with fewer sampling steps. Sampling seems to require the latents to stay "close" to the learned distribution, i.e. the latents at step t-1 must be very similar to those at step t.
Allan does this seem plausible? Any ideas on how to mitigate the effect?
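For reference, here is roughly how I understand the warm-started sampling loop (a minimal sketch with hypothetical signatures, not this repo's exact code):

```python
import torch

@torch.no_grad()
def sample_with_latent_warm_start(model, denoise_step, x_T, num_latents, dim_latent, timesteps):
    # the latents returned at step t are fed back as the latent self-condition at step t-1,
    # so consecutive steps must see latents "close" to what the model saw during training
    x = x_T
    latents = torch.zeros(x.shape[0], num_latents, dim_latent, device = x.device)  # cold start
    for t in reversed(range(timesteps)):
        pred, latents = model(x, t, latent_self_cond = latents)  # reuse, don't reset
        x = denoise_step(x, pred, t)                             # usual DDPM / DDIM update
    return x
```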
Hi @Lamikins, For shorter sampling I found that DDIM works better. The way the latent self-condition is trained atm is a bit naive... it would be interesting to try unrolling the network, training it with different gaps between the latent self-cond, or using progressive distillation (perhaps with bptt). And yea, eps pred works significantly better; also good to use the lamb optimizer with weight decay on kernels and pos embeddings.
oh oops, i don't remember why i had predict x0 objective as default
will get the predict eps objective in there tomorrow morning! (as well as the resolution dependent noise schedules from Ting Chen)
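for reference, the two objectives are related by the standard DDPM identity, so switching is just a change of prediction target (a sketch of the algebra, not this repo's exact code):

```python
def x0_from_eps(x_t, eps, alpha_bar_t):
    # from x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    return (x_t - (1 - alpha_bar_t) ** 0.5 * eps) / alpha_bar_t ** 0.5

def eps_from_x0(x_t, x0, alpha_bar_t):
    return (x_t - alpha_bar_t ** 0.5 * x0) / (1 - alpha_bar_t) ** 0.5
```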
Great ideas @ajabri, I'm investigating the bptt unrolling idea. I also think the latent shifting idea is good, but haven't been able to get it to work, will try again.
Phil here are the CIFAR10 generations I got after 70k steps at bs=512, AdamW 1e-3, eps_pred, DDPM 1000 steps. It took longer to train than I expected, around 12 hours on 1xa100, and I think I can achieve better results with another 12 hours. However, I calculated that it should only have taken ~6-12 hours total, assuming 1xa100 = 2x TPUv3 (3 hours on 8 chips). Let me know what results you are able to get.
Darius
yup it looks good
early generations look unfamiliar and checkerboardy given the patch training, but it converges nicely
i will try Ting Chen's logsnr shifting idea soon
@Lamikins LAMB optimizer with lr=1e-3 and wd on pos embeddings + kernels should work better. Sampling with 250 steps of DDIM should be decent. Should train for at least 150k steps (w/ bs=1024, so longer with bs=512). What kind of arch did you use?
hmm, i think LAMB is strictly a google thing, all the papers i've read suggest it doesn't do that much
but i'll stay open minded and look around for an open sourced implementation (or code up my own)
@lucidrains In case it might save you time, note that the logsnr shifting isn't necessary for images at 256 and below (only really helps significantly when going to 512 and 1024, and mainly for faster convergence). Also, you'll find that adamw converges to higher losses and worse sample metrics across the board; from my experience, LAMB is very useful when learning latent cross attention (see also use in Perceiver/PerceiverIO).
@lucidrains yes ofc bigger is better :) (and the main motivation for RIN) ... but training a 512/1024 model is ofc more expensive (e.g. the 1024px model was trained on 256-TPUv4 w bsz=512 for 1M iterations), so may not be ideal for tuning.
@lucidrains actually some friends put together this open-sourced LAMB: https://github.com/cybertronai/pytorch-lamb
It's 3 years old and slow so it might be worth forking + updating
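For anyone setting this up, a rough sketch of the recommended configuration, assuming the Lamb class from the repo above follows the usual torch.optim constructor (params, lr=..., weight_decay=...); the parameter-name filter is just an illustration and should be adapted to the actual model:

```python
from pytorch_lamb import Lamb  # import path assumed from the repo linked above

def build_lamb(model, lr = 3e-3, wd = 0.01):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # weight decay only on kernels (weight matrices) and positional embeddings,
        # not on biases or norm parameters
        if param.ndim >= 2 or 'pos_emb' in name:
            decay.append(param)
        else:
            no_decay.append(param)
    return Lamb([
        {'params': decay, 'weight_decay': wd},
        {'params': no_decay, 'weight_decay': 0.0},
    ], lr = lr)
```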
ok, i've incorporated Ting Chen's findings in the latest version, so 512 / 1024 training for RIN should be optimal by setting the scale to be less than 1
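roughly, scaling the input by b < 1 shifts the log-SNR of every step by 2 * log(b), so all steps get noisier; a sketch of the forward process with that scaling (my paraphrase of Ting Chen's input-scaling trick, argument names hypothetical):

```python
import torch

def q_sample_scaled(x0, gamma_t, noise, scale = 0.7):
    # x_t = sqrt(gamma_t) * b * x0 + sqrt(1 - gamma_t) * noise, with b < 1
    # this shifts log-SNR by 2 * log(b), which mainly matters at 512 / 1024 resolution
    return (gamma_t ** 0.5) * scale * x0 + ((1 - gamma_t) ** 0.5) * noise

# usage: x_t = q_sample_scaled(x0, gamma_t, torch.randn_like(x0), scale = 0.7)
```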
@Lamikins sorry I missed your message above, here are some useful hyperparams for CIFAR-10. Optimizer: LAMB with lr=3e-3, cosine decay, weight decay of 0.01, dropout 0.1. Arch: 3 blocks of depth 2, 128 latents, 512 latent dim, 256 interface dim.
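For anyone mapping that config onto this repo's kwargs, a rough guess (kwarg names taken from the README and earlier comments; the interpretation of "3 blocks of depth 2" and the patch size are assumptions):

```python
from rin_pytorch import RIN

model = RIN(
    dim = 256,                   # interface (image token) dim
    dim_latent = 512,            # latent dim
    num_latents = 128,           # number of latents
    image_size = 32,             # CIFAR-10
    patch_size = 2,              # assumption, not stated in the comment above
    depth = 3,                   # "3 blocks"
    latent_self_attn_depth = 2,  # "of depth 2" (my interpretation)
)
```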
closing this issue, as mostly resolved except for the tangent on optimizers
@Lamikins What FID can you get on CIFAR10?
Hey @ajabri. I've tried to reproduce your results on CIFAR but I could not achieve the training speed reported in the paper (3h on 8 TPUv3). I'm running on 8 V100s with a global batch size of 256 and doing 150k iterations, using the config you shared with 3 blocks of 2 modules. My training speed is around 4.5 it/s, so about 9h of training. I'm curious whether you had any tricks to speed things up, and what training speed you achieved? Thanks!
Hi Nicolas,
We've finally been able to open source the code:
Architecture: https://github.com/google-research/pix2seq/blob/main/architectures/tape.py
Image diffusion: https://github.com/google-research/pix2seq/blob/main/models/image_diffusion_model.py
CIFAR-10 config: https://github.com/google-research/pix2seq/blob/main/configs/config_diffusion_cifar10.py
Hope this helps!
Allan
Hi Phil,
Thanks for your interest in the architecture!
I just wanted to point out that we don't actually tie the weights across RIN blocks of the same forward pass: https://github.com/lucidrains/recurrent-interface-network-pytorch/blob/main/rin_pytorch/rin_pytorch.py#L398 This is important because tying the weights in the forward pass (which I experimented with) gives the model significantly less capacity. Will highlight this more in the paper's pseudo-code in an updated version.
In that way, the architecture is actually more similar to models like Perceiver / BigBird / memory transformer, rather than ISAB, because achieving depth without reinitializing "inducing points" is quite important for the model to learn how to use the latents effectively (ISAB mostly achieves clustering of the data tokens at each block, since it resets inducing points at each block).
You can view RIN as factoring a global-self attention matrix that operates on a set of [global, data tokens] into global-global, global-data (write), data-global (read), and data-data interactions. We never use data-data interactions (like Perceiver), but show that adding global-data interactions (writing) is important. We use global-global interactions (like BigBird / memory transformer), but apply them more frequently than other interactions; importantly, global tokens have more channels and are processed with wider blocks.
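A schematic sketch of one such block (hypothetical module names, using torch's built-in attention; time conditioning, norms, and feedforwards omitted for brevity):

```python
import torch
from torch import nn

class RINBlockSketch(nn.Module):
    # read (latents attend to data), several latent-only self-attentions (wider, applied more often),
    # then write (data attends to latents); there is never data-data attention
    def __init__(self, dim_data, dim_latent, heads = 8, latent_self_attn_depth = 2):
        super().__init__()
        self.read = nn.MultiheadAttention(dim_latent, heads, kdim = dim_data, vdim = dim_data, batch_first = True)
        self.latent_attns = nn.ModuleList([
            nn.MultiheadAttention(dim_latent, heads, batch_first = True)
            for _ in range(latent_self_attn_depth)
        ])
        self.write = nn.MultiheadAttention(dim_data, heads, kdim = dim_latent, vdim = dim_latent, batch_first = True)

    def forward(self, data, latents):
        latents = latents + self.read(latents, data, data)[0]       # gather information from data tokens
        for attn in self.latent_attns:
            latents = latents + attn(latents, latents, latents)[0]  # global-global computation
        data = data + self.write(data, latents, latents)[0]         # broadcast back to data tokens
        return data, latents

# blocks are stacked without weight tying, as noted above
blocks = nn.ModuleList([RINBlockSketch(dim_data = 256, dim_latent = 512) for _ in range(3)])
```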
The 'recurrent' aspect actually comes from the fact that warm-starting the latents across iterations (i.e. forward passes) gives you an effect akin to an RNN, since you can propagate latent information; the result is that the model is effectively very deep at inference.
Happy to discuss further :) Allan