google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

implement gsam in jax #8

Closed · juntang-zhuang closed this 2 years ago

juntang-zhuang commented 2 years ago

Hi @lucasb-eyer, thanks for your review and comments. I reformatted the files and squashed the commits into a new PR (sorry, I messed up the old PR and could not squash the commits there). This PR includes:

  1. Put the GSAM-related configs into config.gsam and call GSAM with l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels, lr=learning_rate, **config["gsam"]).
  2. Add big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py; the network used in the GSAM paper uses pool_type='gap' and rep_size=False, which differs from the default config.
  3. Fix format issues and squash commits.

Regarding reproducing the experiments, I wonder if it would be possible for you to run the script (with 8x8 TPU cores, to exactly match the paper)? I'm sorry, but I don't have access to TPU resources since I'm no longer affiliated with Google, so I can't run the experiments myself, though the checkpoints and the old version of the code that I used are still kept on a server. Thanks so much for your code review and help!

lucasb-eyer commented 2 years ago

Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

juntang-zhuang commented 2 years ago

> Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

Thanks a lot! If we can't figure out the effective wd schedule, I might need to either implement the old-version weight-decay schedule or re-tune the hyper-parameters for the new setting. I wonder if you could point Ting to the docs on how to run this repository internally; I'll submit the code externally, so we could re-run some experiments to reproduce the results.

lucasb-eyer commented 2 years ago

hey, sorry I got distracted by something urgent to finish, will get back to this in one of the next two weeks and am optimistic we can get it to work well :)

edit: however, you did not yet tell me which exact number from the paper the config should be reproducing?

juntang-zhuang commented 2 years ago

Thanks for the response. Sorry about the missing number, it's supposed to reproduce the 76.8 for ViT-B/32 in Table 1 of https://openreview.net/pdf?id=edONMAnhLu- .

I'm not fully sure about the new wdecay and lr schedulers. In the old version, the lr scheduler is a single function (here the lr schedule function seems to be chained with a bunch of other transformations); also, in the old version, wdecay is multiplied by lr, so wdecay effectively follows a schedule rather than being a constant. Is the new wdecay set to a constant?
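To make the question concrete, here is a toy illustration of the two conventions I mean (my own notation, not actual bv_optax code):

    # Toy illustration only, not actual bv_optax code.
    def sgd_update_old_style(param, grad, lr_t, wd):
      # Old version: wd is multiplied by lr_t, so the effective decay strength
      # follows the lr schedule.
      return param - lr_t * grad - lr_t * wd * param

    def sgd_update_constant_wd(param, grad, lr_t, wd):
      # The question: is the new wd applied at a constant strength,
      # independent of the lr schedule?
      return param - lr_t * grad - wd * param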

lucasb-eyer commented 2 years ago

Oh, and you have a bunch of small issues like wrong indentation, trailing spaces, etc. It would be helpful if you could run pylint with this config over it, so that I don't need to fix these later on.

lucasb-eyer commented 2 years ago

And another minor nitpick: could you rename the config from ...1k... to ...i1k...? We never call ImageNet "1k", but always "i1k", throughout the codebase. I assume it was a typo.

lucasb-eyer commented 2 years ago

Here is the training loss from running this config, sweeping over wd=0.0009 (=0.3*0.003, should be exactly the same as in the paper), 0.001 (a nicer number close to the previous one), and 0.3 (just in case). The loss is crazy, and accuracy is, and stays at, random (not shown):

[training-loss plot]

However, I find it somewhat suspicious that the loss starts at 693.15, roughly 100x the standard starting loss on i1k (log(1000)=6.907). I noticed the config uses the sigmoid_xent loss, and your paper mentions neither "softmax" nor "sigmoid"; could it be that you trained with softmax_xent and have sigmoid_xent here in the config by mistake? I'll try a run with that instead, but please take another careful read over the config and see if you can find other sources of this.
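A quick back-of-the-envelope check on that hunch (assuming the per-class sigmoid losses get summed rather than averaged):

    import math
    # Zero-logit init: each of the 1000 sigmoid terms contributes ln(2), so a
    # summed sigmoid_xent starts at 1000 * ln(2), while softmax_xent over
    # 1000 classes starts at ln(1000).
    print(1000 * math.log(2))  # ~693.15, the observed starting loss
    print(math.log(1000))      # ~6.91, the usual i1k softmax starting loss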

Another thing: the config does not set config.init_head_bias, which we often, but not always, use. Could this also be a mistake? (I'll schedule an experiment about this too.)

juntang-zhuang commented 2 years ago

Thanks a lot for the experiments; it seems the config is not correct. I'll discuss it with Ting and see if we can directly compare the config file with the one we used for the experiments.

lucasb-eyer commented 2 years ago

So far no luck: neither of the changes (sigmoid->softmax, head-bias init) made it any better.

Then, I also tried the following things:

  1. Disabling weight decay altogether, to check whether I can at least overfit. Nope, still an exploding loss, so the issue seems unrelated to wd(?)
  2. A model with cls-token and MLP head (repr_size=True), as this was the original ViT setup. A complete disaster :)

So, I have tried all the ideas I had regarding the configuration, and at this point I wonder if there's a bug in the implementation. Could you please try on your side? Note that you don't need TPU access to run big_vision; it works great on GPUs too, and we updated the README with instructions for that. Let me know when you figure out a setting/code change such that the loss no longer explodes in the first few hundred steps, and I can then try longer runs for you again. (I'll also send Ting my runs internally.)

lucasb-eyer commented 2 years ago

I forgot to mention it, but I also tried a run with Adam's 1st momentum not in bfloat16 but in regular float32, and it makes no difference. Note this bfloat16 really just affects the 1st-momentum buffer, nothing else.
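(For context, I believe this toggle corresponds to something like optax's mu_dtype knob, which only changes the dtype of the first-moment accumulator:)

    import jax.numpy as jnp
    import optax

    # Presumably equivalent to toggling something like optax's mu_dtype,
    # which only affects the dtype of Adam's 1st-moment accumulator (mu),
    # not nu or the updates themselves.
    tx_bf16 = optax.scale_by_adam(mu_dtype=jnp.bfloat16)
    tx_fp32 = optax.scale_by_adam(mu_dtype=jnp.float32)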

lucasb-eyer commented 2 years ago

Ting shared with me your exact runs behind the paper numbers, so I could dig in a bit more. Carefully replicating exactly the config that was run, I still get similar behaviour, though slightly less extreme ("only" going up to the hundreds, not millions):

[training-loss plot]

At this point, I feel like this must be a bug in the code. It seems to go wrong after ~500 steps, potentially you can even run that on CPUs to debug?

juntang-zhuang commented 2 years ago

Thanks a lot for the feedback and experiments. I'll dig into it with Ting and will post a working version here. Sorry for all the trouble with this PR.

lucasb-eyer commented 2 years ago

> Sorry for all the trouble with this PR

No worries, I will be happy and thankful to have up-to-date GSAM and SAM in the codebase!

evcu commented 2 years ago

I also tried to run this with alpha=0, and it looks slightly better at the start, but it still explodes after 1-2k steps.

lucasb-eyer commented 2 years ago

I just noticed in one of your changes a few days ago, you did find a bug:

    learning_rate = sched_fns[0](step)   # Wrong
    learning_rate = sched_fns[0](step) * config["lr"]   # Your fix

This looks very promising! So I patched it in and tried another run on top of the last one I mentioned here. It looks a lot better! It doesn't explode, and reaches 75.2/81.8/61.0 validation/real/v2 accuracy after 90 epochs. This is not yet the expected 76.8/82.7/63.0 we're trying to reproduce, but it's getting much closer :partying_face:

However, the missing 1.6% is still significant, so we should find it before merging this. I carefully compared the configs (already before, but once again) and didn't find a new discrepancy. With alpha=0 I should get SAM, right? Were the SAM and Vanilla numbers in Table 1 also produced by you, or copied from somewhere? If produced by you, I could also run SAM and Vanilla and see if I can reproduce them; that would give us an indication of where the remaining mistake might be.

Here are a few metrics, do they all look reasonable to you?

[metrics plots]

juntang-zhuang commented 2 years ago

@lucasb-eyer Thanks so much for running the experiments! I'm also running an experiment on ViT-S/32, but it takes much longer on my GPU machine; I'll post the results here once it finishes.

The results for SAM are copied from https://arxiv.org/abs/2106.01548, Table 2.

As for the 1.6% gap: in previous updates I made a few changes that could make a difference, including the following:

  1. Pass the absolute learning rate learning_rate = sched_fns[0](step) * config["lr"] instead of learning_rate = sched_fns[0](step).
  2. In config.gsam, set absolute values lr_max=config.get_ref('lr') and lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr').
  3. In config.schedule, set linear_end=0.01 (rather than linear_end=0.00003).
  4. Pass flax.jax_utils.replicate(step) when calling update_fn.

(I'm not sure if 4 is necessary, just following my old code after meeting with Ting.)

For 1, it's my fault that I did not realize bv_optax defines the learning-rate schedule in a relative manner, while all my code from last year assumed absolute lr values. This caused a bug in my previous PR: I passed the relative lr to the numerator but the absolute lr_max/lr_min to the denominator, which results in a perturbation amplitude about 300x too large. Such a big perturbation would crash the network. In the current version this is fixed.

For 2 and 3, these were also caused by my mistake with the lr schedule. To reproduce the paper results, the absolute learning rate follows a linear decay with max_lr=0.003 and min_lr=0.00003. Expressed as a relative schedule, that corresponds to linear_end = 0.00003 / 0.003 = 0.01; see the sketch below.
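Putting 1-3 together, the wiring looks roughly like this (my paraphrase, not the exact PR code):

    # My paraphrase of fixes 1-3, not the exact PR code.
    # bv_optax schedule fns return a *relative* multiplier, so the absolute lr
    # has to be recovered by scaling with config["lr"]:
    learning_rate = sched_fns[0](step) * config["lr"]

    # config.gsam carries the *absolute* endpoints of the linear lr schedule,
    # i.e. lr_max = 0.003 and lr_min = linear_end * lr = 0.01 * 0.003 = 0.00003:
    l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images,
                             targets=labels, lr=learning_rate, **config["gsam"])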

I have merged the changes above into the latest PR; let me know if you have time to take a look. I'm also reproducing the ViT-S/32 results on my machine; it's a bit slow, but I will post them here once I get results. Thanks again for your help with this!

lucasb-eyer commented 2 years ago

No need to blame yourself alone, I also should have noticed ALL of these during review and testing, but didn't :) Happy you found them now! Let me start some runs right away, for 300ep, and report back later today.

I actually ran all experiments on 8x8, but am curious why TPU topology would influence the results?

juntang-zhuang commented 2 years ago

Cool, I'm really excited to see the updated results; they outperform the numbers in the paper! I have updated the PR according to your comments, except that the step is passed to update_fn rather than read out from opt.

One minor thing: GSAM reduces to SAM only with alpha=0 and rho_max=rho_min in the gsam_gradient function. Basically, SAM uses a constant perturbation radius rho_t, while GSAM scales rho_t in proportion to the learning-rate schedule. Making it constant by setting rho_max=rho_min might not be the nicest interface; using a bv_optax-style schedule function might be better for code-style consistency.

Regarding the TPU count: GSAM / SAM performs a per-worker perturbation based on the per-worker gradient in gsam_gradient, so more workers means more distinct perturbations, and the model effectively sees more neighbors in parameter space.

lucasb-eyer commented 2 years ago

Thanks for your comments. My "SAM" run with rho_max=rho_min=0.15 just finished, and it's quite a bit better than the paper number too. From my reading of the code, when rho_max=rho_min we do use a constant rho value, independent of the learning rate (schedule), no?

[plot]

And yes, making it use the actual schedule_fn from optax would be ideal: then we could simply use SAM with all kinds of schedules and wouldn't need to manually specify lr_min/lr_max in the config anymore. That would be a lot better, but I felt I had already asked a lot from you, so I didn't want to ask for that too :) If you want to do it, that's great; otherwise I may do it at some point, or maybe never, if we never need it. But this is the largest argument against having it in the core trainer for now.

lucasb-eyer commented 2 years ago

Regarding the per-host perturbations, I noticed that the model soups paper states that not syncing may be a significant disadvantage:

[screenshot from the model soups paper]

so it may be worth implementing. Do I understand correctly that it basically means doing jax.lax.pmean(g_clean)?

lucasb-eyer commented 2 years ago

I also just realized that we should add a pointer to this from the README. I'll do so early next week too.

juntang-zhuang commented 2 years ago

Thanks so much for your help with the debug and PR!

Regarding the rho_t schedule: yes, it is constant when rho_max=rho_min. I implemented it so that rho_t follows the same schedule as lr_t (just with a different value range). It might be better to pass rho_t as another sched_fn, but I'm not familiar with the chain-style functions in bv_optax, so I'm not confident I can implement that correctly and consistently with the existing codebase.
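Concretely, the rho_t logic I mean is roughly the following (a paraphrase; the exact code is in gsam.py):

    # Paraphrase of the rho_t schedule logic (the exact code is in gsam.py):
    # rho_t is an affine rescaling of lr_t from [lr_min, lr_max] onto
    # [rho_min, rho_max], so rho_max == rho_min gives a constant rho (plain SAM).
    def rho_t(lr, lr_max, lr_min, rho_max, rho_min):
      if lr_max == lr_min:
        return rho_max
      return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)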

Regarding the per-worker perturbation, the model soups paper seems to contradict the original SAM paper (https://arxiv.org/pdf/2010.01412.pdf, Section 4.1), which defines m-sharpness, where m is the per-worker number of examples. There, a smaller m (hence a larger number of workers when the total batch size is fixed) improves generalization.

I'm not sure about the model soups implementation. In my implementation (and in SAM), the process is:

  1. Compute the per-worker gradient g_clean (not synced) and the per-worker perturbation param_sam: https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L69
  2. Compute the per-worker gradient g_gsam at the (per-worker) perturbed model weights param_sam: https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L91
  3. Average g_gsam across workers in https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/train.py#L211 (note that the grads returned by the gsam_gradient function here are g_gsam, not g_clean).
  4. All workers update with the same globally averaged g_gsam in the optimizer.

Again, I'm not certain about the model soups setup, but if it draws the opposite conclusion from the SAM paper, I suspect it might come from a different implementation. For example, if it swaps the order of 3 and 4, i.e. first performs a per-worker parameter update with the per-worker g_gsam and then averages the model weights across workers, this might harm performance compared to a synced perturbation.

If we want to perform a synced perturbation, we can add g_clean = jax.lax.pmean(g_clean, axis_name) after https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L56 so that param_sam is the same on all workers; see the sketch below.
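A minimal sketch of that one-line change (assuming the surrounding pmap uses the usual axis_name="batch", as in the big_vision trainers):

    # Minimal sketch of the synced-perturbation variant (not part of this PR).
    # Averaging the clean gradient across devices before computing the
    # perturbation makes param_sam identical on every worker. Assumes the
    # surrounding pmap uses axis_name="batch", as in the big_vision trainers.
    g_clean = jax.lax.pmean(g_clean, axis_name="batch")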