huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. #2124

Closed jsdt closed 2 weeks ago

jsdt commented 2 weeks ago

Before this change, if you use a VarMap to build a BatchNorm and then train it, the variables for running_mean and running_var in the VarMap won't be updated as the model is trained. That's because BatchNorm calls Var::from_tensor to create those variables, which creates new variables that aren't tracked in the VarMap.
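Concretely, the disconnect comes from `Var::from_tensor` copying the data into a brand-new variable. A minimal standalone sketch using only candle-core (not the candle-nn code itself):

```rust
use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::zeros(4, DType::F32, &dev)?;
    // Var::from_tensor copies the data into a fresh variable, so
    // updates to `v` are invisible through `t` (and through any
    // VarMap that was tracking `t`).
    let v = Var::from_tensor(&t)?;
    v.set(&Tensor::ones(4, DType::F32, &dev)?)?;
    assert_eq!(t.sum_all()?.to_scalar::<f32>()?, 0.0); // `t` is unchanged
    Ok(())
}
```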

After this change, if the tensor that batch_norm gets from the VarBuilder is already a variable, we just reuse it instead of creating a new one.
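Roughly, the idea looks like this (a sketch following the PR title, not necessarily the exact diff; `make_var` stands in for the internal helper that copies data into new variable storage):

```rust
impl Var {
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        if t.is_variable() {
            // Already a variable: clone the handle so it keeps sharing
            // storage with (and stays tracked by) the VarMap's entry.
            Ok(Self(t.clone()))
        } else {
            // Not a variable: copy into fresh variable storage as before.
            let inner = t.make_var()?; // assumed internal helper
            Ok(Self(inner))
        }
    }
}
```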

It still feels error-prone, since it is possible to create Vars that can be updated in forward_train without ever being saved, but I'm not sure what the best way to fix that is. Maybe BatchNorm shouldn't update running_mean and running_var if it was created with detached tensors; a sketch of that alternative is below.
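Purely as a sketch of that alternative (not existing code), BatchNorm could record at construction time whether its running stats arrived as real variables and skip the update otherwise:

```rust
// At construction time, inside batch_norm (hypothetical):
let track_running_stats = running_mean.is_variable() && running_var.is_variable();

// Later, in forward_train (hypothetical):
if self.track_running_stats {
    // ... update running_mean / running_var with the batch statistics ...
}
```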

LaurentMazare commented 2 weeks ago

I would be a bit worried about modifying the semantics of Var::from_tensor, as it's supposed to not modify the original tensor. The idea is that one should use VarBuilder::get_with_hints to initialize variables within candle-nn modules. Did you try using the batch_norm function here to create the batch norm, passing it the VarMap as a var-builder, or did the issue occur when creating the BatchNorm some other way?
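(For reference, that construction path looks something like this; a sketch, with the "bn" prefix and feature count chosen arbitrarily:)

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};

fn build() -> Result<(VarMap, BatchNorm)> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // batch_norm initializes weight/bias/running stats through
    // get_with_hints, so they are registered in the VarMap.
    let bn = batch_norm(16, BatchNormConfig::default(), vb.pp("bn"))?;
    Ok((varmap, bn))
}
```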

jsdt commented 2 weeks ago

Yes, I ran into this issue using that batch_norm function. The problem is that it calls get_with_hints, but then calls Var::from_tensor with the result here, which creates a new variable (disconnected from the one in the VarBuilder).

I added a test illustrating the issue. If you run that test without the change, the last assertion will fail.
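The test is along these lines (a sketch of its shape; the exact input values and the "bn.running_mean" key name are assumptions):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

#[test]
fn batch_norm_tracks_running_stats() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let bn = batch_norm(2, BatchNormConfig::default(), vb.pp("bn"))?;

    // One training pass with a non-zero-mean input should move
    // running_mean away from its zero initialization.
    let xs = Tensor::ones((8, 2, 4), DType::F32, &dev)?;
    bn.forward_train(&xs)?;

    let running_mean = varmap
        .data()
        .lock()
        .unwrap()
        .get("bn.running_mean")
        .unwrap()
        .as_tensor()
        .clone();
    // Without the fix this still reads all zeros, because BatchNorm
    // updated a disconnected copy rather than the VarMap's variable.
    assert!(running_mean.sum_all()?.to_scalar::<f32>()? > 0.0);
    Ok(())
}
```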

LaurentMazare commented 2 weeks ago

Ah, makes sense indeed, thanks for catching this!