beacon-biosignals / LighthouseFlux.jl

An adapter package that implements Lighthouse's framework interface for Flux
MIT License

usage of trainmode!/testmode! #6

Open jrevels opened 4 years ago

jrevels commented 4 years ago

All tests actually still pass for me locally if I remove our manual trainmode!/testmode! usage. Do we really need to manipulate trainmode!/testmode! manually, then, or is Flux's default istraining behavior (which considers BatchNorm et al. to be "active" only if the forward pass is executed within Zygote) sufficient?
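For concreteness, here is a minimal standalone sketch of that default behavior (illustrative only, not package code; the variable names are made up, and it assumes the mechanism described above, where Zygote's adjoint for istraining flips BatchNorm into its active mode during differentiation):

using Flux, Zygote

# Illustrative sketch: with no explicit trainmode!/testmode! call,
# BatchNorm is "active" only when the forward pass runs under Zygote.
bn = BatchNorm(1)
x = reshape(Float32[1, 2, 3, 4], 1, 4)

# Outside Zygote: istraining() is false, so the layer uses its running
# (initially zero-mean, unit-variance) statistics.
y_inference = bn(x)

# Under Zygote.pullback: the istraining() adjoint returns true, so the
# layer normalizes with the current batch's statistics instead.
y_training, back = Zygote.pullback(() -> bn(x), Flux.params(bn))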

If we do need manual trainmode!/testmode! calls, we probably want to make sure we still handle them correctly in the event of an exception:

diff --git a/src/LighthouseFlux.jl b/src/LighthouseFlux.jl
index 55ccc2b..d5af80b 100644
--- a/src/LighthouseFlux.jl
+++ b/src/LighthouseFlux.jl
@@ -85,24 +85,27 @@ end

 function Lighthouse.train!(classifier::FluxClassifier, batches, logger)
     Flux.trainmode!(classifier.model)
-    weights = Zygote.Params(classifier.params)
-    for batch in batches
-        train_loss, back = log_resource_info!(logger, "train/forward_pass";
-                                              suffix="_per_batch") do
-            f = () -> loss(classifier.model, batch...)
-            return Zygote.pullback(f, weights)
-        end
-        log_value!(logger, "train/loss_per_batch", train_loss)
-        gradients = log_resource_info!(logger, "train/reverse_pass";
-                                       suffix="_per_batch") do
-            return back(Zygote.sensitivity(train_loss))
-        end
-        log_resource_info!(logger, "train/update"; suffix="_per_batch") do
-            Flux.Optimise.update!(classifier.optimiser, weights, gradients)
-            return nothing
+    try
+        weights = Zygote.Params(classifier.params)
+        for batch in batches
+            train_loss, back = log_resource_info!(logger, "train/forward_pass";
+                                                  suffix="_per_batch") do
+                f = () -> loss(classifier.model, batch...)
+                return Zygote.pullback(f, weights)
+            end
+            log_value!(logger, "train/loss_per_batch", train_loss)
+            gradients = log_resource_info!(logger, "train/reverse_pass";
+                                           suffix="_per_batch") do
+                return back(Zygote.sensitivity(train_loss))
+            end
+            log_resource_info!(logger, "train/update"; suffix="_per_batch") do
+                Flux.Optimise.update!(classifier.optimiser, weights, gradients)
+                return nothing
+            end
         end
+    finally
+        Flux.testmode!(classifier.model)
     end
-    Flux.testmode!(classifier.model)
     return nothing
 end
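For what it's worth, the point of the finally block is easy to see in a standalone sketch (hypothetical code, not from the package): if anything between trainmode! and testmode! throws, the model is still restored to test mode, whereas without the try/finally its BatchNorm layers would stay stuck in train mode for later evaluation calls.

using Flux

# Hypothetical sketch of the try/finally pattern proposed above.
model = Chain(Dense(2, 2), BatchNorm(2))
Flux.trainmode!(model)
try
    error("simulated failure mid-training")
catch err
    @warn "training failed" exception=err
finally
    Flux.testmode!(model)  # runs whether or not an exception was thrown
end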
anoojpatel commented 4 years ago

If it's the case that explicit trainmode!/testmode! calls don't actually buy us anything in the current state of model training, then relying on the default istraining() mechanism is sufficient.