google-research / big_vision

Reduce peak memory usage when freezing parameters. #14

Closed · lkhphuc closed 2 years ago

lkhphuc commented 2 years ago

I discovered optax.set_to_zero() from this thread.
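For reference, here is a minimal sketch (not the big_vision trainer itself) of how the two transforms differ when freezing a subtree via optax.multi_transform; the parameter tree, labels, and shapes are made up for illustration:

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameter tree; the names and shapes are illustrative only.
params = {
    "backbone": {"w": jnp.ones((4, 4))},  # subtree to freeze
    "head": {"w": jnp.ones((4, 2))},      # subtree to train
}
param_labels = {"backbone": {"w": "frozen"}, "head": {"w": "train"}}

# Current approach: multiply frozen updates by 0.0. The full-size update
# tree is still materialized before being zeroed out.
tx_scale = optax.multi_transform(
    {"train": optax.adam(1e-3), "frozen": optax.scale(0.0)}, param_labels)

# PR approach: replace frozen updates with zeros outright, giving XLA a
# better chance to drop the computation that produced them.
tx_zero = optax.multi_transform(
    {"train": optax.adam(1e-3), "frozen": optax.set_to_zero()}, param_labels)

grads = jax.tree_util.tree_map(jnp.ones_like, params)
state = tx_zero.init(params)
updates, state = tx_zero.update(grads, state, params)
# updates["backbone"]["w"] is all zeros, so the backbone never moves.
```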

When compared with the original optax.scale(0.0) on a ViT-H/16 with some heads, peak GPU memory usage (measured with XLA_PYTHON_CLIENT_PREALLOCATE=false) is lower with optax.set_to_zero():

The frozen weights were set in the config like this (identically for both the current code and the PR):

  config.schedule = [
    (".*ViT_0/.*", None),             # None = no schedule: freeze these params
    (".*", dict(warmup_steps=2500)),  # everything else trains normally
  ]
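For context, a rough sketch of how such a first-match-wins rule list could be resolved into per-parameter settings, with None meaning "frozen"; the helper and the parameter names below are hypothetical, not big_vision internals:

```python
import re

schedule = [
    (r".*ViT_0/.*", None),             # None -> parameter is frozen
    (r".*", dict(warmup_steps=2500)),  # default schedule for the rest
]

def first_matching_rule(param_name, rules):
  # Return the config of the first pattern that fully matches the name.
  for pattern, cfg in rules:
    if re.fullmatch(pattern, param_name):
      return cfg
  raise ValueError(f"no rule matched {param_name!r}")

print(first_matching_rule("img/ViT_0/encoder/kernel", schedule))  # None (frozen)
print(first_matching_rule("head/kernel", schedule))               # {'warmup_steps': 2500}
```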

Theoretically, memory usage should be the same after jitting (one would expect XLA to optimize away the zeroed updates either way), so I'm not sure whether this is a GPU-specific bug in jax or not.
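For anyone wanting to reproduce the numbers, this is roughly the measurement setup (assuming a GPU backend; Device.memory_stats() is not available on all backends and may return None):

```python
import os
# Disable XLA's upfront allocation so allocator stats reflect real usage;
# must be set before jax is first imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

stats = jax.local_devices()[0].memory_stats()
if stats:
  print(f"peak bytes in use: {stats['peak_bytes_in_use'] / 2**20:.1f} MiB")
```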