Open henrypinkard opened 1 year ago
Similar issue here. The first few calls to some functions trigger additional compilations, which seems to make me OOM sooner than expected. Does yours recompile every time?
Looking into this...
Apologies: I updated to the latest versions of JAX (0.4.18 -> 0.4.21) and Flax (0.7.2? -> 0.8.0), and the issue seems to have resolved itself. I no longer see this behavior in my logs.
This has fixed the OOMing:
train_step = jax.jit(train_step, donate_argnums=(0,))  # 0 is the index of the state argument
If it helps: earlier, calling "state = train_step(state, batch); state = train_step(state, batch)" sequentially was causing two compilations, even after wrapping with jax.jit(train_step). My XLA flags in the environment were empty.
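For reference, here is a minimal sketch of that pattern (the train_step body, shapes, and values are placeholders, not from this thread): jit the step once, outside the training loop, with the state argument donated so XLA can reuse its buffers in place; subsequent calls with the same shapes and dtypes then hit the compilation cache.

import jax
import jax.numpy as jnp

def train_step(state, batch):
    # placeholder update; a real step would compute gradients and apply them
    return jax.tree_util.tree_map(lambda p: p - 0.1 * jnp.mean(batch), state)

# jit once, donating argument 0 (the state) so its buffers can be reused in place
train_step = jax.jit(train_step, donate_argnums=(0,))

state = {"w": jnp.zeros((4,)), "b": jnp.zeros(())}
batch = jnp.ones((8, 4))

state = train_step(state, batch)  # first call: traces and compiles
state = train_step(state, batch)  # same shapes/dtypes: reuses the cached executable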
I logged this re-compilation occurring several times during training, on a fixed evaluation function that uses the same inputs every time. It seemed to cause random OOMs after training and evaluating stably for an hour or so. I also filtered my inputs to throw out any batches that didn't match the intended shape, but no samples were ever caught.
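One way to narrow this kind of thing down (a debugging sketch, not something from this thread; eval_step, params, and eval_batch are placeholder names): enable JAX's compile logging and print the exact structure, shapes, and dtypes the jitted function receives, since those are what the jit cache keys on.

import jax

jax.config.update("jax_log_compiles", True)  # log whenever a jitted function (re)compiles

def describe(tree):
    # pytree structure plus per-leaf shape/dtype (and any static args) determine cache hits
    return jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), tree)

# before each evaluation call, e.g.:
# print(describe(eval_batch))
# metrics = eval_step(params, eval_batch)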
Regarding the original issue of bias_init running twice: it turns out that the .param method runs the initializer under eval_shape when the param already exists, in order to do some validation. So you will see the initializer code run each time, but only the first call actually performs allocations.
https://github.com/google/flax/blob/e172c768965034acbe1d42a50c258e06c6400d43/flax/core/scope.py#L956
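To make the eval_shape point concrete, here is a small standalone sketch (an illustration of the mechanism, not the actual Flax code): jax.eval_shape traces the Python function, so side effects such as print still fire, but it returns only shape/dtype information and allocates no device buffers.

import jax
import jax.numpy as jnp

def init_fn(key, shape):
    print("initializer traced")  # Python side effects still run during tracing
    return jnp.zeros(shape)

out = jax.eval_shape(lambda key: init_fn(key, (3,)), jax.random.PRNGKey(0))
print(out)  # ShapeDtypeStruct(shape=(3,), dtype=float32) -- no buffer was allocated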
Thanks! Pretty unintuitive, but glad to know I wasn't doing something incorrectly.
I've run into a strange behavior, and I'm unsure if it's a bug or if I'm doing something wrong; I wasn't able to find any clarification in the docs. It seems that the weights of my network get initialized every time apply is called, instead of just the first time, as seen in the example below where I create a custom initialization function. This seems like it would be inefficient, so I'm wondering if I'm missing something.
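The code example referred to above was not preserved in this copy of the thread; a minimal sketch of the kind of setup being described (module and initializer names here are illustrative, not the reporter's actual code) might look like:

import jax
import jax.numpy as jnp
import flax.linen as nn

def my_bias_init(key, shape, dtype=jnp.float32):
    print("bias_init called")  # fires on init and, per the explanation above, on every apply
    return jnp.zeros(shape, dtype)

class MyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        kernel = self.param("kernel", nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features))
        bias = self.param("bias", my_bias_init, (self.features,))
        return x @ kernel + bias

model = MyDense(features=3)
x = jnp.ones((2, 4))
params = model.init(jax.random.PRNGKey(0), x)  # prints once and allocates the parameters
y = model.apply(params, x)                     # prints again, but only under eval_shape

As explained earlier in the thread, the second print comes from the shape-validation pass inside .param, not from a real re-initialization.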