All tests actually still pass for me locally if I remove our manual trainmode!/testmode! usage. Do we really need to manually manipulate trainmode!/testmode!, then, or is Flux's default istraining behavior (which considers BatchNorm et al. to be "active" during forward pass only if forward pass is executed within Zygote) sufficient?
If we do need trainmode!/testmode! manually, we probably want to make sure we still handle it correctly in the event of an exception:
```diff
diff --git a/src/LighthouseFlux.jl b/src/LighthouseFlux.jl
index 55ccc2b..d5af80b 100644
--- a/src/LighthouseFlux.jl
+++ b/src/LighthouseFlux.jl
@@ -85,24 +85,27 @@ end
 function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
     Flux.trainmode!(classifier.model)
-    weights = Zygote.Params(classifier.params)
-    for batch in batches
-        train_loss, back = log_resource_info!(logger, "train/forward_pass";
-                                              suffix="_per_batch") do
-            f = () -> loss(classifier.model, batch...)
-            return Zygote.pullback(f, weights)
-        end
-        log_value!(logger, "train/loss_per_batch", train_loss)
-        gradients = log_resource_info!(logger, "train/reverse_pass";
-                                       suffix="_per_batch") do
-            return back(Zygote.sensitivity(train_loss))
-        end
-        log_resource_info!(logger, "train/update"; suffix="_per_batch") do
-            Flux.Optimise.update!(classifier.optimiser, weights, gradients)
-            return nothing
+    try
+        weights = Zygote.Params(classifier.params)
+        for batch in batches
+            train_loss, back = log_resource_info!(logger, "train/forward_pass";
+                                                  suffix="_per_batch") do
+                f = () -> loss(classifier.model, batch...)
+                return Zygote.pullback(f, weights)
+            end
+            log_value!(logger, "train/loss_per_batch", train_loss)
+            gradients = log_resource_info!(logger, "train/reverse_pass";
+                                           suffix="_per_batch") do
+                return back(Zygote.sensitivity(train_loss))
+            end
+            log_resource_info!(logger, "train/update"; suffix="_per_batch") do
+                Flux.Optimise.update!(classifier.optimiser, weights, gradients)
+                return nothing
+            end
         end
+    finally
+        Flux.testmode!(classifier.model)
     end
-    Flux.testmode!(classifier.model)
     return nothing
 end
```
If it's the case that explicit `trainmode!`/`testmode!` calls don't actually buy us anything in the current state of model training, then relying on the default `istraining()` mechanism is sufficient.
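To illustrate the default behavior in question, here is a minimal sketch (assuming Flux's Zygote-era `:auto` mode, in which a normalization layer counts as "active" only when its forward pass runs under Zygote; the layer and data names are hypothetical):

```julia
using Flux, Zygote, Statistics

bn = BatchNorm(2)             # fresh layer: running mean = 0, running var = 1
x = randn(Float32, 2, 16)

# Plain forward pass: `:auto` resolves to testmode, so the (initial)
# running statistics are used and `y_plain ≈ x` for a fresh layer.
y_plain = bn(x)

# Forward pass under Zygote: `:auto` resolves to trainmode, so batch
# statistics are used and each feature row of `y_grad` is normalized
# to roughly zero mean over the batch.
y_grad, _ = Zygote.pullback(() -> bn(x), Flux.params(bn))

mean(y_grad; dims=2)          # close to zero per feature in train-style mode
```

If this matches what we rely on, the manual `trainmode!`/`testmode!` calls are redundant for any layer left in its default `:auto` state.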