Closed · jsdt closed this 2 weeks ago
I would be a bit worried to modify the semantics of `Var::from_tensor`, as it's supposed not to modify the original tensor. The idea is that one should use `VarBuilder::get_with_hints` to initialize variables within candle-nn modules. Did you try using the `batch_norm` function here to create the batch-norm, passing it the `VarMap` as a var-builder, or did the issue occur with another way of creating the `BatchNorm`?
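For reference, a minimal sketch of that usage, assuming the usual `VarMap`/`VarBuilder` setup from candle-nn (the feature count and function name `build_bn` are illustrative):

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};

fn build_bn() -> Result<BatchNorm> {
    // Back a VarBuilder with a VarMap so that every variable created through
    // get_with_hints is tracked (and can later be saved/loaded).
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // Create the batch-norm through candle-nn, passing the VarMap-backed builder.
    batch_norm(4, BatchNormConfig::default(), vb)
}
```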
Yes, I ran into this issue using that `batch_norm` function. The problem is that it calls `get_with_hints`, but then calls `Var::from_tensor` with the result here, which creates a new variable (disconnected from the one in the `VarBuilder`). I added a test illustrating the issue: if you run that test without the change, the last assertion will fail.
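Roughly, the failing scenario looks like this (a sketch of the kind of test, not the exact one added in the PR; the `running_mean` key and the input shape are assumptions based on how `batch_norm` names its variables):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

#[test]
fn batch_norm_updates_varmap() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let bn = batch_norm(2, BatchNormConfig::default(), vb)?;

    // Snapshot the running mean tracked by the VarMap before training.
    let before = varmap.data().lock().unwrap()["running_mean"].to_vec1::<f32>()?;

    // A single training-mode forward pass should update the running stats.
    let xs = Tensor::randn(0f32, 1f32, (8, 2, 4, 4), &dev)?;
    let _ys = bn.forward_train(&xs)?;

    let after = varmap.data().lock().unwrap()["running_mean"].to_vec1::<f32>()?;

    // Without the fix this fails: forward_train updates a disconnected Var,
    // so the VarMap still holds the initial values.
    assert_ne!(before, after);
    Ok(())
}
```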
Ah, makes sense indeed, thanks for catching this!
Before this change, if you use a `VarMap` to build a `BatchNorm`, then train the `BatchNorm`, the variables for `running_mean` and `running_var` in the `VarMap` won't be updated as the model is trained. That's because it calls `Var::from_tensor` to create those variables, which creates new variables that aren't tracked in the `VarMap`.

After this change, if the tensor that `batch_norm` gets from the `VarBuilder` is a variable, we just reuse it instead of creating a new one.

It still feels error-prone, since it is possible to create `Var`s that can be updated in `forward_train` without ever being saved, but I'm not sure what the best way to fix that is. Maybe `BatchNorm` shouldn't update `running_mean` and `running_var` if it is created with detached tensors.
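To make the failure mode concrete, here is a small standalone demonstration (plain candle-core, not code from this PR) of why `Var::from_tensor` disconnects the running stats:

```rust
use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::zeros(2, DType::F32, &dev)?;
    // Var::from_tensor copies `t` into fresh storage, so the Var and the
    // original tensor are disconnected from this point on. This is why
    // running stats created this way never show up in the VarMap.
    let v = Var::from_tensor(&t)?;
    v.set(&Tensor::ones(2, DType::F32, &dev)?)?;
    assert_eq!(t.to_vec1::<f32>()?, vec![0.0, 0.0]); // original unchanged
    assert_eq!(v.to_vec1::<f32>()?, vec![1.0, 1.0]); // only the copy moved
    Ok(())
}
```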