Open cgarciae opened 1 year ago
Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer having a `use_running_average` member variable in `ConvBlock`. Perhaps in the future we can add a `use_running_average=None` argument in `ConvBlock.__call__` if there is sufficient demand, then use `nn.merge_param` just like Flax does, but my general preference is to configure the behaviour of the module during construction (with `@nn.compact` you do both at once anyway).
Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.
Hey @n2cholas!
This is not an immediate issue, but I was playing around with `jax_resnet` and noticed that `ConvBlock` decides whether it should update its batch statistics depending on whether the `batch_stats` collection is mutable or not. This initially sounds like a safe bet, but if you embed `ResNet` inside another module that by chance also uses `BatchNorm`, and you want to train the other module but freeze `ResNet`, it is not clear how you would do this.

https://github.com/n2cholas/jax-resnet/blob/5b00735aa0a68ec239af4a728ad4a596c1b551f6/jax_resnet/common.py#L43-L44
To solve this you have to:

1. Add a `use_running_average` (or equivalent) argument in `ConvBlock.__call__` and pass it to `norm_cls`.
2. Change `ResNet` to be a custom Module (instead of `Sequential`) so it also accepts this argument in `__call__` and passes it to the relevant submodules that expect it.

Some repos use a single `train` flag to determine the state of both BatchNorm and Dropout.

Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.