tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Still able to access raw tensor data after close() is called #460

Open ku222 opened 2 years ago

ku222 commented 2 years ago

System information

Describe the current behavior

After closing a tensor output from a concrete function (and the concrete function itself), I am still able to access the tensor's raw data.

Describe the expected behavior

I expect an IllegalStateException when accessing the raw tensor data, because close() was already called on the tensor.

Code to reproduce the issue

// Create simple (a + b) function
ConcreteFunction addTwoInts = ConcreteFunction.create((tf) -> {
    Placeholder<TInt32> inputA = tf.placeholder(TInt32.class);
    Placeholder<TInt32> inputB = tf.placeholder(TInt32.class);
    Add<TInt32> output = tf.math.add(inputA, inputB);
    return Signature.builder().key("add")
            .input("a", inputA)
            .input("b", inputB)
            .output("out", output)
            .build();
});

// Apply to input
Map<String, Tensor> input = Map.of(
        "a", TInt32.scalarOf(1),
        "b", TInt32.scalarOf(2));
Tensor output = addTwoInts.call(input).get("out");

// Close everything
addTwoInts.close();
output.close();

// Expect java.lang.IllegalStateException: close() was called on the Tensor
output.asRawTensor().data();  // However, no exception is thrown here

I recently upgraded from v0.2.0, where the same test case setup would throw the expected exception.
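The contract the reporter expects (any read after close() fails fast with an IllegalStateException) can be sketched in plain Java. This is only an illustration under stated assumptions: `GuardedBuffer` and `UseAfterCloseDemo` are hypothetical names invented here, and nothing below reflects how TensorFlow Java actually implements Tensor or RawTensor.

```java
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for a native-backed tensor. Not TensorFlow Java code.
final class GuardedBuffer implements AutoCloseable {
    private final int[] values;
    private final AtomicBoolean closed = new AtomicBoolean(false);

    GuardedBuffer(int... values) {
        this.values = values.clone();
    }

    // Every accessor checks the closed flag before touching the data.
    int[] data() {
        if (closed.get()) {
            throw new IllegalStateException("close() was called on the buffer");
        }
        return values.clone();
    }

    // Idempotent: closing twice is harmless.
    @Override
    public void close() {
        closed.set(true);
    }
}

final class UseAfterCloseDemo {
    // Returns true if a read after close() is rejected.
    static boolean rejectsUseAfterClose() {
        GuardedBuffer buffer = new GuardedBuffer(1, 2, 3);
        buffer.data();   // allowed while the buffer is open
        buffer.close();
        try {
            buffer.data();
            return false; // a silent read here is the behavior reported in this issue
        } catch (IllegalStateException expected) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(rejectsUseAfterClose());
    }
}
```

In this sketch the guard lives in the accessor itself, so it does not matter which code path produced the object or whether the function that produced it has already been closed.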

karllessard commented 2 years ago

The semantics of ConcreteFunction have changed a lot since 0.2.0. It now returns a callable that can be called eagerly or attached to another graph, replicating what Python does. The previous implementation of ConcreteFunction is closer to what we now call a SessionFunction, except that the function no longer owns a graph/session: these must be passed explicitly and managed outside the scope of the function, like this:

void main() {
    try (var g = new Graph(); var s = new Session(g)) {
        var function = SessionFunction.create(myFunc(Ops.create(g)), s);
        try (var result = function.call(...)) {
            ...
        }
    }
}

Signature myFunc(Ops tf) {
    Placeholder<TInt32> inputA = tf.placeholder(TInt32.class);
    Placeholder<TInt32> inputB = tf.placeholder(TInt32.class);
    Add<TInt32> output = tf.math.add(inputA, inputB);
    return Signature.builder().key("add")
            .input("a", inputA)
            .input("b", inputB)
            .output("out", output)
            .build();
}

Normally, a SessionFunction is only used to run a SavedModelBundle in a functional way. But I would like to restore the original behavior of ConcreteFunction in this class, where the lifecycle of the graph and the session required to execute the function is managed within the function itself, as I also find it useful (and functions run by a graph are much faster than those run by an eager session).

That being said, can you compare your observations against the newest version of the library (e.g. 0.4.1) and share your results?

ku222 commented 2 years ago

Thanks for getting back so quickly, @karllessard!

I can see that the API for ConcreteFunction has changed in the way that you've described. However, the docs state that neither ConcreteFunction nor SessionFunction takes ownership of the output tensors returned from call(...). As such, it is the caller's responsibility to close the output tensors.
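The ownership rule described above (the function does not own what call(...) returns, so the caller must close it) can be sketched without TensorFlow at all. Everything here is an assumption made for illustration: `TrackedResult`, `AddFunction`, and `CallerOwnsOutputsDemo` are hypothetical names, not part of the TensorFlow Java API.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical output handle that records whether it has been closed.
final class TrackedResult implements AutoCloseable {
    private final int value;
    private boolean closed;

    TrackedResult(int value) {
        this.value = value;
    }

    int value() {
        if (closed) {
            throw new IllegalStateException("close() was called on the result");
        }
        return value;
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        closed = true;
    }
}

// Hypothetical function object that keeps no reference to its outputs.
final class AddFunction implements AutoCloseable {
    // Each call returns fresh results that the caller must close.
    Map<String, TrackedResult> call(int a, int b) {
        Map<String, TrackedResult> outputs = new LinkedHashMap<>();
        outputs.put("out", new TrackedResult(a + b));
        return outputs;
    }

    @Override
    public void close() {
        // Releases only the function's own resources, never the outputs.
    }
}

final class CallerOwnsOutputsDemo {
    static int addAndRelease(int a, int b) {
        TrackedResult result;
        try (AddFunction function = new AddFunction()) {
            result = function.call(a, b).get("out");
        }
        // The function is closed here, but the result is still open:
        // releasing it is the caller's job.
        try (TrackedResult owned = result) {
            return owned.value();
        }
    }

    public static void main(String[] args) {
        System.out.println(addAndRelease(1, 2));
    }
}
```

Under this model, closing the function and closing the output are independent steps, which is why the two function types would be expected to behave the same once the caller closes the output.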

With this in mind, I've set up some like-for-like tests comparing ConcreteFunction and SessionFunction. When run locally with v0.4.1, I observe that the bug is still present. (NB: I am using the same Function<Ops, Signature> to add two TInt32s as earlier in the thread.)

Test Case 1

@Test
public void testWithSessionFunctionThenCallCloseOnOutput() {
    Map<String, Tensor> input = Map.of("a", TInt32.scalarOf(1), "b", TInt32.scalarOf(2));

    Tensor result;
    try (var graph = new Graph(); var session = new Session(graph)) {
        var ops = Ops.create(graph);
        var function = SessionFunction.create(addSig(ops), session);
        result = function.call(input).get("out");
    }

    result.close();

    // java.lang.IllegalStateException: close() was called on the Tensor
    result.asRawTensor().data();
}

Test Case 2

@Test
public void testWithConcreteFunctionThenCallCloseOnOutput() {
    Map<String, Tensor> input = Map.of("a", TInt32.scalarOf(1), "b", TInt32.scalarOf(2));

    Tensor result;
    try (var graph = new Graph(); var session = new Session(graph)) {
        var ops = Ops.create(graph);
        try (var function = ConcreteFunction.create(addSig(ops), session)) {
            result = function.call(input).get("out");
        }
    }

    result.close();

    // No exception is thrown
    result.asRawTensor().data();
}

Would we not expect the same behaviour here between the two function types?

Thanks again for the help!