tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

API to read weights from ckpt file #420

Open kushal-g opened 2 years ago

kushal-g commented 2 years ago


Describe the feature and the current behavior/state. This API would read the weights from a ckpt file. Currently the Java SDK has no such feature, but the same can be done in Python via CheckpointReader.

Will this change the current API? How? This shouldn't affect the preexisting API. It would be a simple addition over it.

Who will benefit with this feature? I was developing a federated learning application and got stuck because I could not get the weights from the ckpt file (reason: no API in the Java SDK), nor could I write a signature method in TFLite for the same purpose (reason: TFLite doesn't support dynamic tensor shapes). So this feature would help promote active FL development on Android clients.

Craigacp commented 2 years ago

session.restore(String path) should do that already - https://github.com/tensorflow/java/blob/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java#L709.
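One detail that may help when calling it: as I read the Javadoc, the argument is a checkpoint prefix rather than a single file. A minimal sketch of deriving that prefix from a file name, assuming the TF2-style naming `<prefix>.index` and `<prefix>.data-NNNNN-of-NNNNN`; `CheckpointPrefix` and `fromFile` are hypothetical helper names, not part of TensorFlow Java.

```java
import java.util.regex.Pattern;

/** Hypothetical helper: turns a checkpoint file name into the prefix shared by its files. */
final class CheckpointPrefix {
    // Matches the shard suffix of a checkpoint data file, e.g. ".data-00000-of-00001".
    private static final Pattern DATA_SUFFIX = Pattern.compile("\\.data-\\d{5}-of-\\d{5}$");

    private CheckpointPrefix() {}

    /** Strips a trailing ".index" or ".data-NNNNN-of-NNNNN"; other paths are returned unchanged. */
    static String fromFile(String path) {
        if (path.endsWith(".index")) {
            return path.substring(0, path.length() - ".index".length());
        }
        return DATA_SUFFIX.matcher(path).replaceFirst("");
    }

    public static void main(String[] args) {
        // Prints the derived prefix for each path given on the command line.
        for (String arg : args) {
            System.out.println(fromFile(arg));
        }
    }
}
```

The returned string is what would be passed to `session.restore`. A single-file V1 checkpoint does not follow this naming, so the helper leaves such paths untouched.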

kushal-g commented 2 years ago

How would I get the graph for creating the session? I need to get the weights from the following ckpt file.

https://drive.google.com/file/d/1ZH1SiF4Z-YYGZjwPbtd4cpCbvGXYIuu2/view?usp=sharing

This operation will take place in an Android app, so that the weights can be sent to a server for aggregation using federated learning.

Craigacp commented 2 years ago

The graph should be defined somewhere as a protobuf, but you can also specify the model yourself. How are you loading in the model architecture at the moment?

kushal-g commented 2 years ago

I am not. I assumed that the model architecture could be loaded from the ckpt itself. I have the model architecture in Python code. From what I understand from your reply, I have two options:

1) Export the model architecture as protobuf using Python code, then load the architecture in my Java code to generate the graph, and then use the session to get the weights.
2) Reimplement the model inside my Java code and then use that graph to load the session and get the weights.

Is that correct? If so, could you guide me to some documentation for this?

Craigacp commented 2 years ago

Variable checkpoints are stored separately from model structure in TensorFlow, so you need both components to load in a model. The SavedModel format is a directory containing the model structure protobuf and a variable checkpoint, so you can load that in and continue training. There are other ways of exporting the computation graph, but those are harder to do now in TF 2.x. Otherwise you'll need to implement the model in Java (which might prove a bit tricky, as you'll need to ensure that the variables all have the same names as the variables in Python).
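To make the two components concrete, here is a rough sketch that checks whether a directory has the usual SavedModel layout: `saved_model.pb` for the structure and a checkpoint under `variables/` with the prefix `variables/variables`. It uses only the Java standard library; `SavedModelLayout` and its methods are hypothetical names, and it assumes the binary `.pb` form rather than a text `.pbtxt` export.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical helper: inspects a directory for the two parts of a SavedModel. */
final class SavedModelLayout {
    private SavedModelLayout() {}

    /** True if the model structure protobuf is present. */
    static boolean hasGraph(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"));
    }

    /** Checkpoint prefix inside the SavedModel directory: <dir>/variables/variables. */
    static Path variablesPrefix(Path dir) {
        return dir.resolve("variables").resolve("variables");
    }

    /** True if the checkpoint index file for that prefix is present. */
    static boolean hasVariables(Path dir) {
        return Files.isRegularFile(variablesPrefix(dir).resolveSibling("variables.index"));
    }

    public static void main(String[] args) {
        // Reports on the directory given as the first argument, or the current directory.
        Path dir = Paths.get(args.length > 0 ? args[0] : ".");
        System.out.println("graph present:     " + hasGraph(dir));
        System.out.println("variables present: " + hasVariables(dir));
        System.out.println("checkpoint prefix: " + variablesPrefix(dir));
    }
}
```

If only the variables part is present, that corresponds to the situation described in this thread: a checkpoint without the graph needed to restore it.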