google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Multiple initializations. Is this a bug? #3499

Open · henrypinkard opened this issue 1 year ago

henrypinkard commented 1 year ago

I've run into a strange behavior, and I'm unsure whether it's a bug or whether I'm doing something wrong; I wasn't able to find any clarification in the docs. It seems that the weights of my network get initialized every time `apply` is called, instead of just the first time, as seen in the example below, where I create a custom initialization function. This seems like it would be inefficient, so I'm wondering if I'm missing something.

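```python
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class MyModule(nn.Module):
    def setup(self):
        def my_bias_init(rng, shape, dtype=jnp.float32):
            print('bias init')  # side effect to see when the initializer runs
            return random.uniform(rng, shape, dtype=dtype, minval=0, maxval=2)
        self.a_layer = nn.Dense(10, bias_init=my_bias_init)

    def __call__(self, x):
        return self.a_layer(x)

x = jnp.ones((1, 4))
params = MyModule().init(random.PRNGKey(0), x)  # prints 'bias init'
y = MyModule().apply(params, x)  # prints 'bias init' again (Flax 0.7.4)
```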

System information


```text
Name: flax
Version: 0.7.4
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: 
Author: 
Author-email: Flax team <flax-dev@google.com>
License: 
Location: /2tb_nvme/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: 
---
Name: jax
Version: 0.4.18
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /2tb_nvme/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.18+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /2tb_nvme/hpinkard_waller/mambaforge/envs/phenotypes/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax, orbax-checkpoint
```

samblouir commented 1 year ago

Similar issue here. Calling some functions the first few times triggers additional compilations, and this seems to make me run out of memory (OOM) sooner than expected. Does yours always recompile?

cgarciae commented 1 year ago

Looking into this...

samblouir commented 1 year ago

Apologies: I updated to the latest versions of JAX (0.4.18 -> 0.4.21) and Flax (0.7.2? -> 0.8.0), and this seems to have resolved itself. I no longer see this behavior in my logs.

This fixed the OOMs: `train_step = jax.jit(train_step, donate_argnums=(0,))`, where `0` is the index of the state argument.
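A minimal sketch of the donation pattern above, assuming a hypothetical toy `train_step` and a plain dict standing in for the real train state. `donate_argnums=(0,)` tells XLA it may reuse the input state's device buffers for the outputs, so the old and new state do not both have to live in memory at once. The caller must rebind the name on every call and never touch the donated value again; on backends where donation is not supported, JAX may warn that donated buffers were not usable and simply skip the reuse.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy training step: `state` is a dict of arrays standing in
# for a real TrainState, and `batch` is an input array.
def train_step(state, batch):
    grad = jnp.mean(batch) * jnp.ones_like(state["w"])  # fake gradient
    return {"w": state["w"] - 0.1 * grad, "step": state["step"] + 1}

# Donate argument 0 (`state`) so XLA may reuse its buffers for the output.
train_step = jax.jit(train_step, donate_argnums=(0,))

state = {"w": jnp.zeros((4,)), "step": jnp.array(0)}
batch = jnp.ones((8, 4))
state = train_step(state, batch)  # the previous `state` buffers may now be invalid
state = train_step(state, batch)  # always rebind; never reuse a donated value
print(int(state["step"]))  # 2
```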

If it helps: earlier, calling `state = train_step(state, batch); state = train_step(state, batch)` sequentially was causing two compilations, even after `jax.jit(train_step)`. My XLA flags in the environment were empty.
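One way to confirm whether a jitted function is being retraced (and therefore recompiled) is to count how often its Python body runs, since that only happens during tracing. A minimal sketch, assuming a hypothetical function `step` and a global counter `trace_count` used purely as a debugging aid; JAX's `jax_log_compiles` config option is an alternative that logs each compilation.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical debugging counter, bumped only while JAX traces

def step(x):
    global trace_count
    trace_count += 1  # runs at trace time, not on cached (compiled) calls
    return x * 2

step = jax.jit(step)

x = jnp.ones((8, 4))
step(x)
step(x)                 # same shape and dtype: served from the cache
assert trace_count == 1
step(jnp.ones((7, 4)))  # new shape: forces a fresh trace and compile
assert trace_count == 2
```

If the counter grows on calls whose inputs look identical, comparing the shapes, dtypes, and pytree structure of the arguments between calls is a reasonable next step.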

I logged this recompilation occurring several times during training, on a fixed evaluation function that receives the same inputs every time. This seemed to cause random OOMs after training and evaluating stably for an hour or so. I also filtered my inputs to discard any batches that did not match the intended shape, but no samples were ever caught.
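A minimal sketch of the kind of shape filter described above, assuming NumPy-array batches; `EXPECTED_SHAPE` and `shape_filtered` are hypothetical names. The motivation is that `jax.jit` specializes on input shapes, so any batch with a new shape (for example, a short final batch) would trigger a retrace and recompile.

```python
import numpy as np

EXPECTED_SHAPE = (32, 128)  # hypothetical intended batch shape

def shape_filtered(batches, expected_shape=EXPECTED_SHAPE):
    """Yield only batches whose shape equals `expected_shape`."""
    for batch in batches:
        if batch.shape == expected_shape:
            yield batch
        else:
            print(f"dropping batch with shape {batch.shape}")

batches = [np.zeros((32, 128)), np.zeros((17, 128)), np.zeros((32, 128))]
kept = list(shape_filtered(batches))
print(len(kept))  # 2
```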

cgarciae commented 1 year ago

Regarding the original issue of `bias_init` running twice: it turns out that the `.param` method runs the initializer under `jax.eval_shape` when the parameter already exists, in order to validate it against the stored value. So you will see the code run each time, but only the first time actually performs allocations.

https://github.com/google/flax/blob/e172c768965034acbe1d42a50c258e06c6400d43/flax/core/scope.py#L956

henrypinkard commented 1 year ago

Thanks! Pretty non-intuitive, but I'm glad to know I wasn't doing something incorrectly.