deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

TfModel doesn't support `Model#load(InputStream is)` #3303

Closed tadayosi closed 4 months ago

tadayosi commented 4 months ago

Description

When I try to load a model from a remote URL with the TensorFlow engine, it fails with the error: java.lang.UnsupportedOperationException: Not supported!

Expected Behavior

The TfModel should also support loading models from remote URLs (via InputStream).

Error Message

Exception in thread "main" java.lang.UnsupportedOperationException: Not supported!
        at ai.djl.BaseModel.load(BaseModel.java:126)
        at ai.djl.Model.load(Model.java:145)
        at image_enhancement2.main(image_enhancement2.java:32)

How to Reproduce?

Run this code:

import ai.djl.Model;
import java.net.URI;

System.setProperty("ai.djl.default_engine", "TensorFlow");
var model = Model.newInstance("esrgan-tf2");
var modelUrl = "https://storage.googleapis.com/tfhub-modules/captain-pool/esrgan-tf2/1.tar.gz";
model.load(new URI(modelUrl).toURL().openStream());

Steps to reproduce

  1. Compile the above code
  2. Run the code

What have you tried to solve it?

I noticed that while PtModel implements load(InputStream modelStream, Map<String, ?> options), TfModel doesn't. This should be fixable by implementing the method for TfModel.
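
For comparison, the same stream-based pattern appears to work with the PyTorch engine, because PtModel overrides the stream-loading method. A rough sketch (the model name and URL below are placeholders, and the stream would have to be a raw TorchScript .pt file rather than a .tar.gz archive):

// Sketch only: uses the same ai.djl.Model / java.net.URI imports as above.
// "https://example.com/model.pt" is a placeholder, not a real model location.
System.setProperty("ai.djl.default_engine", "PyTorch");
var ptModel = Model.newInstance("some-pt-model");
ptModel.load(new URI("https://example.com/model.pt").toURL().openStream());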

frankfliu commented 4 months ago

@tadayosi This is a limitation on the TF/Java side. We leverage the TensorFlow Java low-level API to build TfEngine in DJL, and TF/Java only provides TF_LoadSessionFromSavedModel(), which loads the model from a file path.
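
For reference, the Java-side wrapper around it likewise only accepts an export directory on disk, roughly like this (a sketch against the org.tensorflow API; the path and tag are illustrative):

import org.tensorflow.SavedModelBundle;

// The low-level TensorFlow Java API loads a SavedModel from a directory path
// (plus serving tags); there is no InputStream-based equivalent.
SavedModelBundle bundle = SavedModelBundle.load("/path/to/saved_model", "serve");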

tadayosi commented 4 months ago

@frankfliu Thanks for your clarification.

By the way, loading a remote model via Criteria works just fine with the TF engine.

import ai.djl.Application;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;

var criteria = Criteria.builder()
        .setTypes(Image.class, Image.class)
        .optEngine("TensorFlow")
        .optApplication(Application.CV.IMAGE_ENHANCEMENT)
        .optModelUrls("https://storage.googleapis.com/tfhub-modules/captain-pool/esrgan-tf2/1.tar.gz")
        .optTranslator(new ImageEnhancementTranslator()) // a Translator<Image, Image>, defined separately
        .optProgress(new ProgressBar())
        .build();
var model = criteria.loadModel();

I see that the difference is that, in the latter case, the Criteria downloads the remote model locally and the model just loads the local copy. I just wonder why TfModel#load(InputStream modelStream, Map<String, ?> options) cannot do the same internally instead of relying on the TF/Java low-level API.

(The question is just so I can understand the design decisions behind the DJL API, rather than to push my request further :-) )

frankfliu commented 4 months ago

The Model.load(InputStream) API was designed to avoid saving the model file to disk. Some applications have a security requirement that the model file must be encrypted at rest, which means we cannot save the unencrypted model file on disk.
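
For example, such an application can wrap the encrypted file in a decrypting stream and hand it straight to an engine that supports load(InputStream), so the plain model never touches disk. A sketch only; the key source, file name, and model name below are placeholders, not DJL APIs:

import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.spec.SecretKeySpec;
import ai.djl.Model;

// Sketch: decrypt on the fly so the unencrypted model never touches disk.
// loadKeyBytes() is a hypothetical helper; "model.pt.enc" is a placeholder file.
SecretKeySpec key = new SecretKeySpec(loadKeyBytes(), "AES");
Cipher cipher = Cipher.getInstance("AES");
cipher.init(Cipher.DECRYPT_MODE, key);

System.setProperty("ai.djl.default_engine", "PyTorch"); // an engine that implements load(InputStream)
var model = Model.newInstance("my-encrypted-model");
try (InputStream encrypted = Files.newInputStream(Path.of("model.pt.enc"));
     InputStream decrypted = new CipherInputStream(encrypted, cipher)) {
    model.load(decrypted);
}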

Another challenge is that we don't really know the format of the InputStream (.tar, .tgz or .zip), so the unzip/decrypt/cache-on-disk logic has to be implemented at a higher level.
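
So a caller that wants the download-and-extract behaviour has to implement it above the Model API, roughly like this (a sketch only: it shells out to the system tar for brevity, skips error handling, and assumes the extracted archive has the SavedModel at its root):

import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import ai.djl.Model;

// Sketch: cache the remote archive locally, extract it, then point the
// TF engine at the resulting SavedModel directory.
Path dir = Files.createTempDirectory("esrgan-tf2");
Path archive = dir.resolve("model.tar.gz");
try (InputStream is = new URI(
        "https://storage.googleapis.com/tfhub-modules/captain-pool/esrgan-tf2/1.tar.gz")
        .toURL().openStream()) {
    Files.copy(is, archive, StandardCopyOption.REPLACE_EXISTING);
}
new ProcessBuilder("tar", "xzf", archive.toString(), "-C", dir.toString())
        .inheritIO().start().waitFor();

System.setProperty("ai.djl.default_engine", "TensorFlow");
var model = Model.newInstance("esrgan-tf2");
model.load(dir); // TfModel can load a SavedModel from a local path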

tadayosi commented 4 months ago

@frankfliu That makes sense. Thanks again for your clarification!