n2cholas / jax-resnet

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
https://pypi.org/project/jax-resnet/
MIT License
102 stars, 8 forks

Training state of ResNet coupled with mutable batch_stats collection #9

Open cgarciae opened 1 year ago

cgarciae commented 1 year ago

Hey @n2cholas!

This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether to update its batch statistics based on whether the batch_stats collection is mutable. This initially sounds like a safe bet, but if you embed ResNet inside another module that also happens to use BatchNorm, and you want to train the outer module while freezing ResNet, it is not clear how you would do that.

https://github.com/n2cholas/jax-resnet/blob/5b00735aa0a68ec239af4a728ad4a596c1b551f6/jax_resnet/common.py#L43-L44

To solve this you would need per-submodule control over whether batch statistics update, rather than inferring it from collection mutability. Some repos use a single train flag to determine the state of both BatchNorm and Dropout.

Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.

n2cholas commented 1 year ago

Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer a use_running_average member variable on ConvBlock. Perhaps in the future we can add a use_running_average=None argument to ConvBlock.__call__ if there is sufficient demand, then use nn.merge_param just like Flax does, but my general preference is to configure a module's behaviour at construction (with @nn.compact you do both at once anyway).

Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.