Re 1, all the float-like variables in a var store use the same type (by default float32); this simplifies casting to lower precision at the expense of finer control over the actual type. You can find methods to convert your store to use double/... in the documentation. Re 3, I don't have a computer at hand to try it out, but there is some magic applied to files with the `.pt` extension (this magic assumes they were written from the Python PyTorch API). Could you try using another extension? The typical convention for this crate is `.ot`.
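A minimal sketch of the whole-store conversion mentioned above, assuming `tch`'s `VarStore` exposes the casting helpers the documentation describes (the exact method names may differ between versions):

```rust
use tch::{nn, Device};

fn main() {
    // All float-like variables in a VarStore share one kind (Float by default).
    let mut vs = nn::VarStore::new(Device::Cpu);

    // Cast every float-like variable in the store to double precision...
    vs.double();
    // ...or down to half precision, e.g. to shrink a checkpoint.
    vs.half();
}
```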
Re 3, success. Changing the model extension from `.pt` to `.ot` solved the problem. Maybe the code could issue a warning for incorrect extensions?
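For anyone hitting the same thing, a minimal sketch of the fix (file name is illustrative): save and load under the `.ot` extension so the file is not treated as a Python-written PyTorch archive:

```rust
use tch::{nn, Device};

fn main() -> Result<(), tch::TchError> {
    let vs = nn::VarStore::new(Device::Cpu);
    // Using .ot avoids the .pt-specific handling described above.
    vs.save("model.ot")?;

    let mut loaded = nn::VarStore::new(Device::Cpu);
    loaded.load("model.ot")?;
    Ok(())
}
```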
I wouldn't be super keen on having warnings emitted, as they are likely to clutter the process output (I think for now this crate doesn't emit any such warnings). Hopefully people running into this can Google the error message and find this issue.
I had exactly the same issue: I used the `.pt` extension and got this error. Maybe we can document this behaviour?
I want to save some metadata in a model. Below is a complete example of adding tensors to a VarStore using `var_copy()`, then saving/restoring them, along with the output it produces.
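A minimal sketch of that flow, assuming `tch`'s `nn::VarStore`, `Path::var_copy`, `save`, and `load` APIs (constructor names such as `from_slice` vary between tch versions; `model.pt` is kept here because the notes below refer to it):

```rust
use tch::{nn, Device, Tensor};

fn main() {
    // Build a store and copy some "metadata" tensors into it.
    let var_store = nn::VarStore::new(Device::Cpu);
    let epochs = Tensor::from_slice(&[5i64]);        // Int64 tensor
    let learning_rate = Tensor::from_slice(&[1e-3]); // Double tensor
    let e = var_store.root().var_copy("epochs", &epochs);
    let lr = var_store.root().var_copy("learning_rate", &learning_rate);
    println!("before save: e = {:?}, lr = {:?}", e.kind(), lr.kind());
    var_store.save("model.pt").unwrap();

    // Register the same names in a fresh store, then load the file into it.
    let mut loaded_var_store = nn::VarStore::new(Device::Cpu);
    let e = loaded_var_store.root().var_copy("epochs", &epochs);
    let lr = loaded_var_store.root().var_copy("learning_rate", &learning_rate);
    loaded_var_store.load("model.pt").unwrap();
    println!("after load:  e = {:?}, lr = {:?}", e.kind(), lr.kind());
}
```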
So, notes/questions/issues:
- `epochs` is an Int tensor, but the output for `e` shows that it becomes a Float when loaded into the VarStore. The `learning_rate` is similarly changed from a Double tensor to a Float. That seems like a bug: the VarStore shouldn't change the type when copying in a tensor.
- The program panics on the `loaded_var_store.load("model.pt").unwrap();` line. It's an internal panic, not a panic on the `unwrap()`.
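On the first point, a possible workaround sketch (hypothetical helper, assuming `VarStore::variables()` exposes the named tensors): cast the loaded values back to the intended kinds; the epoch count survives exactly, while the learning rate is only recovered at float32 precision:

```rust
use tch::{nn, Kind, Tensor};

// Hypothetical helper: read the metadata back in the intended kinds if the
// store coerced them to Float while loading.
fn read_metadata(vs: &nn::VarStore) -> (i64, f64) {
    let vars = vs.variables();
    let epochs: Tensor = vars["epochs"].to_kind(Kind::Int64);
    let lr: Tensor = vars["learning_rate"].to_kind(Kind::Double);
    (epochs.int64_value(&[0]), lr.double_value(&[0]))
}
```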