tensorflow / rust

Rust language bindings for TensorFlow
Apache License 2.0

There should be a way to obtain a variable from a saved model #394

Open ramon-garcia opened 1 year ago

ramon-garcia commented 1 year ago

If one restores a model with SavedModelBundle::load(), which populates a Graph, there is no way to get Variable objects from that graph. One can obtain an Operation with Graph::operation_by_name_required, but not a Variable.

Variable objects are needed, for instance, to run an optimizer with optimizer.minimize.

ramon-garcia commented 1 year ago

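In addition, in models generated with TensorFlow 2, variables are implemented as ops of type VarHandleOp (resource variables), whereas TensorFlow Rust implements variables as ops of type VariableV2 (reference variables). The two are used differently: a VarHandleOp produces a resource handle that must be read with ReadVariableOp and written with AssignVariableOp, while a VariableV2 output is the value itself and is written with Assign. This can make it more difficult to work with models generated with Python.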
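To make the difference concrete, here is a small, dependency-free Rust sketch that classifies a graph node by its op type string and reports which ops would read and assign it. The names VariableKind, classify, read_op, and assign_op are hypothetical helpers, not part of the tensorflow crate; only the TensorFlow op names (Variable, VariableV2, VarHandleOp, ReadVariableOp, Assign, AssignVariableOp) are real.

```rust
/// Hypothetical classification of variable ops found in a graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VariableKind {
    /// Reference variable (`Variable` / `VariableV2`): output 0 is the value itself.
    Ref,
    /// Resource variable (`VarHandleOp`): output 0 is a handle that must be read.
    Resource,
}

/// Classify a node by its op type string; `None` means "not a variable op".
fn classify(op_type: &str) -> Option<VariableKind> {
    match op_type {
        "Variable" | "VariableV2" => Some(VariableKind::Ref),
        "VarHandleOp" => Some(VariableKind::Resource),
        _ => None,
    }
}

/// Op needed to read the current value, if any.
fn read_op(kind: VariableKind) -> Option<&'static str> {
    match kind {
        // A ref variable's output can be consumed directly.
        VariableKind::Ref => None,
        // A resource handle has to go through ReadVariableOp.
        VariableKind::Resource => Some("ReadVariableOp"),
    }
}

/// Op used to assign a new value to the variable.
fn assign_op(kind: VariableKind) -> &'static str {
    match kind {
        VariableKind::Ref => "Assign",
        VariableKind::Resource => "AssignVariableOp",
    }
}
```

A Variable wrapper that supports both kinds would need to branch along these lines whenever it builds read or assign ops, for example when an optimizer emits its update ops.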

camdenmcgath commented 1 year ago

@ramon-garcia @adamcrume I'd be happy to look into this if this is an issue you guys still need fixed. Would I start by implementing a function like Graph::variable_by_name_required, similar to Graph::operation_by_name_required? Also, do I need to reconcile the use of both the VarHandleOp and VariableV2 op types?
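One possible shape for such a lookup, sketched against a stand-in graph so it compiles without the tensorflow crate. MockGraph, VariableHandle, and this variable_by_name_required are hypothetical; the sketch only illustrates failing with an error for a missing name (as operation_by_name_required does) and accepting both reference and resource variable op types.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a graph: maps operation names to op types.
struct MockGraph {
    ops: HashMap<String, String>,
}

/// Hypothetical result of a successful variable lookup.
#[derive(Debug, PartialEq)]
struct VariableHandle {
    name: String,
    op_type: String,
}

impl MockGraph {
    /// Look up a variable op by name, erroring if it is missing or not a variable.
    fn variable_by_name_required(&self, name: &str) -> Result<VariableHandle, String> {
        let op_type = self
            .ops
            .get(name)
            .ok_or_else(|| format!("operation {:?} not found", name))?;
        match op_type.as_str() {
            // Accept both ref-style (TF1 / TensorFlow Rust) and resource-style (TF2) variables.
            "Variable" | "VariableV2" | "VarHandleOp" => Ok(VariableHandle {
                name: name.to_string(),
                op_type: op_type.clone(),
            }),
            other => Err(format!("operation {:?} has type {}, not a variable", name, other)),
        }
    }
}
```

Keeping the op type in the returned handle lets later code decide whether to emit ReadVariableOp/AssignVariableOp or use the output directly with Assign.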

adamcrume commented 1 year ago

Inside the proto parsed in MetaGraphDef::from_serialized_proto, you can find the collection_def entry with the key "variables"; its value.bytes_list.value should contain one serialized VariableDef proto per variable. That (plus metadata from the tensor referred to by VariableDef.variable_name) should be enough information to create a Variable struct.
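As a rough illustration of that step, here is a dependency-free sketch that hand-decodes one entry of bytes_list.value into a subset of VariableDef using the protobuf wire format. The field numbers follow tensorflow/core/framework/variable.proto (variable_name = 1, initializer_name = 2, snapshot_name = 3, is_resource = 5, trainable = 7); VariableDefLite and decode_variable_def are hypothetical names, and real code would more likely use generated protobuf types instead of a hand-written decoder.

```rust
/// Hypothetical subset of the `VariableDef` proto.
#[derive(Debug, Default, PartialEq)]
struct VariableDefLite {
    variable_name: String,    // field 1, e.g. "w:0"
    initializer_name: String, // field 2
    snapshot_name: String,    // field 3
    is_resource: bool,        // field 5
    trainable: bool,          // field 7
}

/// Read a base-128 varint starting at `*pos`, advancing `*pos` past it.
fn read_varint(buf: &[u8], pos: &mut usize) -> Result<u64, String> {
    let mut value = 0u64;
    let mut shift = 0u32;
    loop {
        let byte = *buf.get(*pos).ok_or("truncated varint")?;
        *pos += 1;
        if shift >= 64 {
            return Err("varint too long".into());
        }
        value |= u64::from(byte & 0x7f) << shift;
        if byte & 0x80 == 0 {
            return Ok(value);
        }
        shift += 7;
    }
}

/// Decode one serialized `VariableDef`, skipping fields this sketch does not model.
fn decode_variable_def(buf: &[u8]) -> Result<VariableDefLite, String> {
    let mut def = VariableDefLite::default();
    let mut pos = 0;
    while pos < buf.len() {
        let key = read_varint(buf, &mut pos)?;
        let (field, wire_type) = (key >> 3, key & 0x7);
        match wire_type {
            // Varint fields: bools and enums.
            0 => {
                let v = read_varint(buf, &mut pos)?;
                match field {
                    5 => def.is_resource = v != 0,
                    7 => def.trainable = v != 0,
                    _ => {} // e.g. synchronization (8), aggregation (9)
                }
            }
            // Length-delimited fields: strings and sub-messages.
            2 => {
                let len = read_varint(buf, &mut pos)? as usize;
                let end = pos
                    .checked_add(len)
                    .filter(|&e| e <= buf.len())
                    .ok_or("truncated length-delimited field")?;
                let bytes = &buf[pos..end];
                pos = end;
                match field {
                    1 => def.variable_name = String::from_utf8_lossy(bytes).into_owned(),
                    2 => def.initializer_name = String::from_utf8_lossy(bytes).into_owned(),
                    3 => def.snapshot_name = String::from_utf8_lossy(bytes).into_owned(),
                    _ => {} // e.g. save_slice_info_def (4), initial_value_name (6)
                }
            }
            // Fixed-width fields (none in VariableDef, but skip them if present).
            1 | 5 => {
                pos += if wire_type == 1 { 8 } else { 4 };
                if pos > buf.len() {
                    return Err("truncated fixed-width field".into());
                }
            }
            other => return Err(format!("unsupported wire type {}", other)),
        }
    }
    Ok(def)
}
```

Each decoded variable_name names a tensor (op name plus output index), so its dtype and shape would still need to be looked up in the Graph before a Variable could be built.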