Open lorenzoh opened 1 year ago
This would fix #5 and is necessary for training models: PyCallChainRules.jl requires functorch models, and those can't contain BatchNorm layers that mutate their running statistics.
Implementation-wise, this boils down to calling
functorch.experimental.replace_all_batch_norm_modules_(model)
on loaded models.
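For reference, here is a minimal Python-side sketch of what that call does. It is not the proposed implementation. The toy `model` is hypothetical and stands in for a loaded model, and the `torch.func` import path is an assumption for newer PyTorch versions, with the `functorch.experimental` path from above as the fallback.

```python
import torch.nn as nn

try:
    # Assumption: PyTorch >= 2.0, where the helper is also exposed under torch.func.
    from torch.func import replace_all_batch_norm_modules_
except ImportError:
    # The import path named in this issue (standalone functorch).
    from functorch.experimental import replace_all_batch_norm_modules_

# Hypothetical stand-in for a loaded model that contains a BatchNorm layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Patches every BatchNorm module in place so that it no longer tracks
# (and therefore no longer mutates) running statistics.
replace_all_batch_norm_modules_(model)

bn = model[1]
print(bn.track_running_stats, bn.running_mean)
```

One thing to keep in mind: once the running statistics are gone, the patched layers normalize with batch statistics in eval mode too, so inference results can differ from the original model.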