tensorflow / java

Java bindings for TensorFlow

Functional graph definition API #181

Open rnett opened 3 years ago

rnett commented 3 years ago

@karllessard: While TF Java has always been graph/session centric, it will gradually move toward a more functional approach, like @tf.function does in Python. This is what the new ConcreteFunction partially achieves, and the core will continue to build up around it to improve support for functions as the main API for building and executing graphs.

@rnett: The concrete function API looks neat. Correct me if I'm wrong (it's been a while since I worked with @tf.function), but the goal is essentially to create what looks like an eager function but is actually backed by a graph? I've been playing around with ideas for something similar in Kotlin using compiler plugins (you could use annotations on functions or lambdas), but I'm not sure how you could do the same in Java without ASM generation.

This is probably more of a Kotlin thing than Java, but have you given any thought to having some mechanism for tensor lifetime scopes (like PointerScope)? It seems like most tensors should be lexically scoped (i.e. via try-with-resources), and having some kind of scoping mechanism would make this a lot easier to manage, while still allowing non- or globally scoped tensors when needed.
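
For example, a minimal sketch of the lexically scoped case with plain try-with-resources (assuming Tensor stays AutoCloseable, as it is today; the tensor values are placeholders):

import org.tensorflow.types.TFloat32;

class ScopedTensors {
    static void compute() {
        // Both tensors are released when the block exits, whatever path it takes.
        try (TFloat32 x = TFloat32.scalarOf(1.0f);
             TFloat32 y = TFloat32.scalarOf(2.0f)) {
            // use x and y here
        }
        // a tensor that must outlive the block has to be created (or detached) outside of it
    }
}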

@karllessard: Thanks @Craigacp and @rnett for your good feedback. My belief is that if we want to improve the usability of our API, we should focus more on functions than on graphs and sessions. @rnett, to answer your question: yes, the goal of a ConcreteFunction is to mimic a bit what Python does, i.e. easily convert a function so it can be called eagerly or backed by a graph. Right now, only graph mode is supported by ConcreteFunction, but nothing prevents a user from directly calling the same method passed as the functionBuilder with an eager session to execute it eagerly. Though I would prefer to make the eager support more explicitly integrated with the function concept. As for whether we should use an annotation like Python does, I guess it could work, but I haven't tried to figure out how it would fit into the current design.

Now for the differences in resource management between the inputs and outputs, I was also aware of this detail. For the sake of brainstorming, maybe reference counting could be useful here. For example, when we pass a tensor to a bundle, we could just increase its reference count so the tensor only gets released once all references are released. Also, ConcreteFunction already has its own way of releasing (or not) its resources (the graph and the session) depending on how it was allocated... I don't have the complete paradigm in mind but we can continue to think about it if we all agree it could be useful, wdyt?

Another point in favor of focusing on the functional API is that it works for both training and inference (after loading a saved model bundle), while using the graph and sessions directly only works well for training since TF2.0, if you remember the issues @Shajan was facing before.

rnett commented 3 years ago

A couple of considerations I ran into when working with my Kotlin version of this:

What I did in my Kotlin experiments was to have Operand functions that work like normal, Layers that have a define function returning an Operand and can declare variables (stored in a map in the object, using positional memoization), and Models, which are layers with specified inputs and outputs. Models are analogous to ConcreteFunctions, and can be built for eager or graph mode and called with Tensors. It's essentially leaving the eager/graph decision up to the model, instead of using something like @tf.function. It's not ideal for cases where you just want a function (like image processing), so I'd add a graph function somewhere, but for deep learning, imo, doing the eager/graph split at the model level works fine.
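
Transposed to Java, the rough shape of that split would be something like this (Layer and Model are illustration-only names from my experiments, not existing TF Java classes):

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TType;

// Hypothetical sketch: a Layer builds its ops against whatever Ops it is given
// (eager or graph) and memoizes its variables internally so define() is reusable.
interface Layer<T extends TType> {
    Operand<T> define(Ops tf, Operand<T> input);
}

// A Model would then be a Layer with declared inputs/outputs that decides once
// whether to execute eagerly or to trace itself into a Graph/ConcreteFunction.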

karllessard commented 3 years ago

Boilerplate. Calling ConcreteFunctions has a lot more boilerplate than just calling an Operand method, especially having to wrap the arguments in a Map.

I totally agree. Python does this mapping automatically by analyzing the names of the arguments in the function invocation, but we cannot do that at runtime in Java. There is already a convenient shortcut for functions that just accept and return a single tensor. Maybe we can change that method to accept a list of positional tensor arguments whose order must follow the function signature? This approach can't be completely error-proof, but it could be acceptable; we could at least validate at runtime that the types and shapes of the tensors match those of the signature and fail gracefully if they don't.

As for the returned tensors, we probably need to use the same utility as was proposed in #167, but we should decouple it from the graph/session execution, as functions could also be run eagerly (or eventually will be).

So maybe something like public TensorList Function.call(Tensor... inputs) { ... }

Again, we could still pass a TensorList as input as well, but it won't be as straightforward to invoke as the proposal above and, as discussed previously with @Craigacp, the TensorList (or whatever we call it) must be able to keep track of a reference count before releasing tensors in an auto-closeable fashion.

Kotlin makes it easier to build up maps, with mapOf and to. If we pass the inputs as a TensorList, then we can also provide some utilities to build up such a list easily, mimicking how Kotlin does it.
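
As a rough illustration (hypothetical names, since neither the list type nor the utility exists yet), such a helper could look like:

import java.util.LinkedHashMap;
import java.util.Map;
import org.tensorflow.Tensor;

// Hypothetical builder mimicking Kotlin's mapOf("a" to t1, "b" to t2) for named inputs.
final class NamedTensors {
    private final Map<String, Tensor> tensors = new LinkedHashMap<>();

    static NamedTensors of(String name, Tensor tensor) {
        return new NamedTensors().and(name, tensor);
    }

    NamedTensors and(String name, Tensor tensor) {
        tensors.put(name, tensor);
        return this;
    }

    Map<String, Tensor> asMap() {
        return tensors;
    }
}

// usage sketch: function.call(NamedTensors.of("a", t1).and("b", t2).asMap())

On Java 9+, Map.of("a", t1, "b", t2) already covers the small fixed-arity case, so a dedicated utility would mostly matter once lifecycle tracking is attached to the collection itself.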

rnett commented 3 years ago

I think a lot of my preferences for this depend on where you see ConcreteFunction being used, specifically whether it's common to use on model parts (layers, functions) or is something you use on the whole model. The Map interface isn't that bad for calling models, but it would get annoying fast for layers.

I see what you mean about the result class, but there will still be some differences from session, in particular the "get from Output/Operand" methods. A TensorMap or TensorList seems like it would work fine for this, although a new class with get(String) and get(int) methods would probably be better. Perhaps the session result class could extend it.

karllessard commented 3 years ago

I think a lot of my preferences for this depend on where you see ConcreteFunction being used,

It is definitely the first choice for running inference on a loaded saved model, as it takes care itself of the tensor/op name mangling that has become even more apparent since TF2.0. A lot of our users will use TF Java specifically for running inference on a model pre-trained in Python, so they won't have to deal with layers. Still, they will need to keep track of the lifetime of the input/output tensors, so we need to make that easy for them. For inference, I don't see why a user would prefer to run their graph using a session instead of a function.

common to use on model parts (layers, functions) or is something you use on the whole model

Functions are mainly used for running a model after it has been built (or compiled), as a whole, but we should probably see how they could be turned into a custom layer when building a model as well.

A TensorMap or TensorList seems like it would work fine for this, although a new class with get(String) and get(int) methods would probably be better. Perhaps the session result class could extend it.

That is something we can look at, yes. Or we could simply add these methods to the TensorMap/List and return null (or throw) if we request an operation that does not exist, whether the tensors in the list are attached to a graph or not.
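
For the sake of discussion, the lookup surface could be as small as this (hypothetical interface, whichever class ends up carrying it):

import java.util.Optional;
import org.tensorflow.Tensor;

// Hypothetical accessor surface: tensors addressable by output name or by position,
// with the collection itself owning their lifecycle.
interface TensorCollection extends AutoCloseable, Iterable<Tensor> {
    Tensor get(int index);              // positional access
    Optional<Tensor> get(String name);  // named access; empty (or throw) when the operation does not exist
    @Override
    void close();                       // releases every tensor still owned by the collection
}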

rnett commented 3 years ago

Functions are mainly used for running a model after it has been built (or compiled), as a whole

Ok, that clears up most of my concerns about the interface. While the Map-style call still isn't ideal, imo it's fine for calling a model, and potentially some special layers.

That is something we can look at, yes. Or we could simply add these methods to the TensorMap/List and return null (or throw) if we request an operation that does not exist, whether the tensors in the list are attached to a graph or not.

Atm there's no way to find Outputs from strings without storing the Graph, although you could use a Map<Output, Tensor> I guess. I've been thinking of TensorMap/List as collections with lifecycle handling; if you want to do Output -> Tensor or String -> Output mapping, they would need a bit more context. There are enough features in Result (String -> Output, type-safe get(Output/Operand), index and key access on the same collection) that I still think it's worth having it be separate, but it would be good to see the initial versions of the tensor collections first.

rnett commented 3 years ago

For tensor lifetimes, is there any reason a finalizer wasn't used? It seems like the perfect solution here, which makes me think I'm missing something.

karllessard commented 3 years ago

We cannot rely on the garbage collector to free native resources like tensors, as it doesn't keep track of their actual size in native memory and thus won't trigger in time, even if the JVM is close to OOM.

In eager mode, we do let the GC collect unused resources allocated during a session (it used to be done directly in TF but is now handled by JavaCPP), but this is just a "best effort" and we ask our users to free their resources explicitly.

Another solution that we've discussed many times is to rely on reference counting (like C++ smart pointers) instead of using try-with-resources blocks all over the place. The user or the library could explicitly invoke retain or release on each resource, and only once the count is back to 0 would it be freed. Though the usability of this paradigm in Java is not clear yet.
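
To illustrate, a minimal sketch of such a reference-counted handle (a hypothetical wrapper, not an existing TF Java class; error handling mostly omitted):

import java.util.concurrent.atomic.AtomicInteger;
import org.tensorflow.Tensor;

// Hypothetical reference-counted handle: the tensor is only closed once every
// holder that retained it has also released it.
final class RefCountedTensor {
    private final Tensor tensor;
    private final AtomicInteger refCount = new AtomicInteger(1);

    RefCountedTensor(Tensor tensor) {
        this.tensor = tensor;
    }

    RefCountedTensor retain() {
        refCount.incrementAndGet();
        return this;
    }

    void release() {
        if (refCount.decrementAndGet() == 0) {
            tensor.close();  // last reference gone: free the native memory
        }
    }

    Tensor get() {
        return tensor;
    }
}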

rnett commented 3 years ago

Right, I didn't realize finalize had so many caveats, that's unfortunate.

Usability-wise, I don't see any benefit to reference counting over close and try-with-resources; you'd still have to call it manually.

Has any thought been given to using scopes, like how PointerScope works? It seems like an easy way to ensure that any tensors except those you explicitly call detach on are closed. It's "specify when not to close" rather than "specify when to close", which from what I've seen, suits the API better, especially eager mode. You'd still need manual closing methods, but most tensors should be handled by their scope.
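
To make the idea concrete, a rough sketch of the shape such a scope could take (TensorScope is a hypothetical name; JavaCPP's PointerScope works along similar lines with its own API):

import java.util.ArrayDeque;
import java.util.Deque;
import org.tensorflow.Tensor;

// Hypothetical scope: every tensor attached to it is closed when the scope closes,
// unless it was explicitly detached first. Sketch only; not thread-aware.
final class TensorScope implements AutoCloseable {
    private final Deque<Tensor> attached = new ArrayDeque<>();

    <T extends Tensor> T attach(T tensor) {
        attached.push(tensor);
        return tensor;
    }

    <T extends Tensor> T detach(T tensor) {
        attached.remove(tensor);  // the caller becomes responsible for closing it
        return tensor;
    }

    @Override
    public void close() {
        attached.forEach(Tensor::close);
        attached.clear();
    }
}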

saudet commented 3 years ago

Has any thought been given to using scopes, like how PointerScope works? It seems like an easy way to ensure that any tensors except those you explicitly call detach on are closed. It's "specify when not to close" rather than "specify when to close", which from what I've seen, suits the API better, especially eager mode. You'd still need manual closing methods, but most tensors should be handled by their scope.

Although it's implicit, that still uses reference counting. It's similar to how Swift does ARC (which might be adopted in the Java language too at some point in the far, far future), so it comes with all the caveats about circular references and whatnot.

Craigacp commented 3 years ago

For tensor lifetimes, is there any reason a finalizer wasn't used? It seems like the perfect solution here, which makes me think I'm missing something.

Using finalizers causes the JVM to do a bunch of extra bookkeeping and tends to put those objects on the slow path for everything. Plus, they are tricky to reason about and not necessarily ever called. It's best to avoid them unless absolutely necessary.

karllessard commented 3 years ago

FYI, I was planning to discuss this topic further on our January 8th video call. @rnett and @saudet, if you can make it, that would be great.

rnett commented 3 years ago

I should be there. @saudet I don't think we'd run into any circular references or anything similar, since all of our references are unidirectional, so to speak: multiple Java objects can reference one tensor. There are no tensors referencing each other, or anything like that. I'm not sure how it interacts with eager mode and operands, but it doesn't seem like it would be too bad.

One issue with any lifecycle solution that we should talk about either now or Friday: NDArrays. Since TTypes are now NDArrays, it's possible to do something like:

TFloat32 x = someTensor();
FloatNdArray y = x.index(...); // or return x as a NDArray, etc

y still refers to the tensor of x and thus its native memory, if I understand the mapping right, but there's no way to tell it's tensor-backed or to close the tensor.

The easiest solutions are to add a close method to NDArray that closes any underlying tensor, or add a way to get an underlying tensor if it exists, but it's still rather inconvenient.
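
To make the two options concrete, a hypothetical sketch (neither method exists today):

import java.util.Optional;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;

// Hypothetical additions: either let the view close its backing tensor directly,
// or expose the backing tensor so the caller can manage it.
interface TensorBackedNdArray<T> extends NdArray<T>, AutoCloseable {
    @Override
    void close();                       // option 1: close the underlying tensor from the view

    Optional<Tensor> backingTensor();   // option 2: empty if this view is not tensor-backed
}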

rnett commented 3 years ago

So, I was bored and decided to play around with a @tf.function style compiler plugin for Kotlin: https://github.com/rnett/tf-function

It was surprisingly easy. The only Java API issue I ran into (other than ConcreteFunction needing more call methods, but that's already known) is that, for this to work properly, I need some way to create Op classes from their outputs, as if they had already been run. This is so I can return the right values from the graph execution. Depending on how we do it in Java, it could be an issue here too, although unless we use bytecode generation it doesn't seem terribly likely. It does not seem that hard to implement, with a FakeEagerOperation class and a new generated factory method.

karllessard commented 3 years ago

That's very interesting; @JimClarke5 raised another case where converting an Output or an Operation back to an Op (Variable in his case) would be helpful.

It shouldn't be hard to do it properly, without the use of a fake class. We probably just need a reverse mapping table to find the right Op class from an operation kernel name and invoke its now-private constructor to instantiate it. The tensorflow-core-generator could take care of generating that table when scanning the @Operator annotated classes.
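
For illustration, the generated table could boil down to something like this (hypothetical names; the entries would be emitted by the generator from the @Operator scan):

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.tensorflow.Operation;
import org.tensorflow.op.Op;

// Hypothetical reverse lookup: op kernel name -> factory wrapping an existing
// Operation back into its typed Op class.
final class OpRegistry {
    private static final Map<String, Function<Operation, Op>> FACTORIES = new HashMap<>();

    static {
        // generated entries, e.g. FACTORIES.put("Add", op -> new Add<>(op));
        // (this is what requires access to the currently private constructors)
    }

    static Op wrap(Operation operation) {
        Function<Operation, Op> factory = FACTORIES.get(operation.type());
        if (factory == null) {
            throw new IllegalArgumentException("No Op class registered for " + operation.type());
        }
        return factory.apply(operation);
    }
}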

rnett commented 3 years ago

I don't think that will work for my use case. I don't have access to the Operation instance: I'm getting Tensors from executing the ConcreteFunction and then bringing them into the eager session with tf.constantOf (and then trying to reconstruct, in eager mode, the graph op that produced them). Essentially I'm trying to convert an Op from graph to eager mode, using the tensors from the already-run graph version. Currently, as far as I can tell, the implementation of the ops depends on getting an Operation instance, and I can't use EagerOperation since it needs the TFE_Op (and I'd have to feed it the TFE_TensorHandle from the constant ops, which seems inadvisable). The operation to use would be detected in the compiler by looking at the type of the variable I'm trying to mock.

For the compiler plugin specifically, the optimal solution for me would be a "fake"/mock Operation class that just takes the outputs and metadata, plus making the generated Op constructors public. I could work with something based on the op name easily enough, assuming all op classes have it statically accessible.
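
Roughly, the data such a mock would need to carry (sketch only; the real class would have to implement org.tensorflow.Operation so the generated Op constructors can accept it):

import java.util.List;
import org.tensorflow.Tensor;

// Hypothetical "fake" eager operation: it holds already-computed tensors standing in
// for the outputs of an op that was actually executed inside a graph.
final class FakeEagerOperation {
    private final String type;   // op kernel name, e.g. "Add"
    private final String name;
    private final List<Tensor> outputs;

    FakeEagerOperation(String type, String name, List<Tensor> outputs) {
        this.type = type;
        this.name = name;
        this.outputs = outputs;
    }

    String type() { return type; }

    String name() { return name; }

    Tensor outputTensor(int index) { return outputs.get(index); }
}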

Craigacp commented 3 years ago

Could you elaborate a little more on what your use case is and what the hurdles are? You can inspect the graph to get the operations back out of it and might be able to unpick it a little more from there.

rnett commented 3 years ago

Basically, I have something like this being generated:

val testFuncProp = FunctionRunner("testFunc") {
    val a = it["a"] as Int? ?: error("a is not present")
    val b = it["b"] as Operand<TInt32>? ?: error("b is not present")

    val c: Operand<TInt32> = math.add(b, constant(2))

    val outputs = mapOf("c" to c)

    val result: (Map<String, Operand<*>>) -> Pair<Int, Operand<TInt32>> = {
        val c = it["c"] as Operand<TInt32>? ?: error("c is not present")

        (a + 2) to c
    }

    FunctionResult(outputs, result)
}

The only things that were there originally are the val c: Operand<TInt32> = math.add(b, constant(2)) and the (a + 2) to c; the variable declarations are replacing the function parameters from the original function.

The goal is to substitute the c in the result lambda (which is an eager Operand) for the c in the main lambda (which is a graph Operand; the main lambda is used to define a ConcreteFunction). The c used in the result lambda originally referenced the declared graph operand from the graph definition method; I'm changing it to reference the c obtained from the map of eager tensors passed to the result lambda. The substitution is done by the compiler plugin, so the types need to match. This works fine as long as the type of the outer c is Operand. However, if the type of the outer c is, say, Add<T>, I need the type of the inner c to also be Add<T>, and since the inner c is created from eager tensors (with tf.constantOf), there is no way to get such an Add op.

In other words, I have some code that looks like val c: Add<TInt32> = math.add(a, b). It's going to be run in graph mode via ConcreteFunction (with c as the output), but I want to get the value of c as if it had been run in eager mode.