Closed · lkhphuc closed this 2 years ago
I discovered `optax.set_to_zero()` from this thread. I compared its peak GPU memory usage against the original `optax.scale(0.0)` on a ViT H/16 with some heads (measured with `XLA_PYTHON_CLIENT_PREALLOCATE=false`):
The frozen weights were set in the config like this (for both the current code and the PR change):

```python
config.schedule = [ (".*ViT_0/.*", None), (".*", dict(warmup_steps=2500)) ]
```
Theoretically the memory usage should be the same after jitting, so I'm not sure whether this is a GPU-specific bug in jax or not.