tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Functions #238

Open rnett opened 3 years ago

rnett commented 3 years ago

Functions

This relies on things described in #237. Being familiar with Python's function semantics will help, although I give an overview.

ConcreteFunction

A ConcreteFunction, both here and in Python, is a wrapper for a native function handle: a set of "compiled" code that can be run, but whose signature can't be changed. Performance optimizations apparently also depend on that signature being well known, i.e. argument shapes should be specified, even where placeholders don't require them.

Python

Functions in Python are defined with tf.function. Since Python is duck-typed, the dtypes and shapes of inputs can change, so a function creates a new ConcreteFunction for each distinct realized input signature (this is called tracing). Type hints can be used to limit retracing.

Supported arguments and what they expose to the signature are:

Of note is that if Python objects are used as arguments, a new ConcreteFunction will be made for each distinct set of argument values, so this is discouraged.

Because of the way tracing works, Python side effects are only executed when the function is traced, which is unreliable: they always run on the first call, but the function is not necessarily retraced on later calls.

Python control flow is converted to TensorFlow control flow (this is impossible for us).

Variable creation and initialization are handled as described in #237; the only thing we need to know here is that global variable creation is forbidden after the first trace.

Outputs can be a single tensor or a tuple of tensors.

Java

For our implementation, I'm envisioning something similar to ConcreteFunction's API, where you would have a method like:

class Example{
  private static Map<String, Operand<?>> defineFunc(Ops tf, Inputs inputs){
    final var x = inputs.input("x", TFloat32.class);
    final var y = tf.math.add(x, tf.constant(2.0f)); // 2.0f gives a TFloat32 constant, matching x
    return Map.of("y", y);
  }

  public static void main(String[] args){
    MapFunction func = Function.define(Example::defineFunc); // traced on the first call, not here
  }
}

The actual syntax is mostly irrelevant and will likely change, but a few things to note:

Note that we have no way of converting Java control flow to TensorFlow's, as is done in Python. This will be noted and highlighted in the docs. It's generally not an issue, though: since we can't resolve tensors in functions, any control flow will depend on Java arguments, which cause a retrace on new values anyway (slightly sub-optimal, but it works).

I also plan to write a Kotlin compiler plugin that does the transform if you annotate the function, like Python does, but I'm not sure officially supporting it is a good idea (the compiler API is unstable and undocumented, although it will be released eventually).

Inputs and Outputs

We use an Inputs builder instead of the existing signature builder because we need to allow more complicated inputs (i.e. not just placeholders).

Like Python, I would want to allow as inputs:

This would work very similarly to how Python handles it, creating a signature for each set of argument values (with the same criteria as Python's) and caching ConcreteFunctions based on that. The parameter list (i.e. names and types) would be static, and we would use it to get the argument signature without re-tracing the function. If a retrace resulted in a different parameter list an error would be thrown.
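A minimal sketch of the signature-keyed cache described above, in plain Java (16+, for records) with no TensorFlow dependency. `ArgSpec`, `Trace`, and `TracingCache` are hypothetical names; dtypes are strings, shapes are lists, and a string stands in for the ConcreteFunction. It traces only for a signature it has not seen yet, and throws if a retrace reports a different parameter list than the first trace did. Because the tracer only runs on a cache miss, any side effects inside it run once per signature, which is the same unreliability described for Python side effects.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical: one argument of a call signature (name, dtype, possibly partial shape).
record ArgSpec(String name, String dtype, List<Long> shape) {
  // The static part of the argument: name and dtype, without the shape.
  String param() {
    return name + ":" + dtype;
  }
}

// Hypothetical: what a trace produces, i.e. the parameter list it saw plus a
// stand-in for the resulting ConcreteFunction.
record Trace(List<String> params, String concreteFunction) {}

// Caches one trace per distinct signature and rejects retraces that change
// the parameter list.
final class TracingCache {
  private final Function<List<ArgSpec>, Trace> tracer;
  private final Map<List<ArgSpec>, Trace> cache = new HashMap<>();
  private List<String> params; // fixed by the first trace
  int traceCount = 0;

  TracingCache(Function<List<ArgSpec>, Trace> tracer) {
    this.tracer = tracer;
  }

  String call(List<ArgSpec> signature) {
    Trace trace = cache.get(signature);
    if (trace == null) {
      // Cache miss: trace. Side effects in the tracer only run here.
      trace = tracer.apply(signature);
      traceCount++;
      if (params == null) {
        params = trace.params();
      } else if (!params.equals(trace.params())) {
        throw new IllegalStateException(
            "Parameter list changed on retrace: " + params + " -> " + trace.params());
      }
      cache.put(signature, trace);
    }
    return trace.concreteFunction();
  }
}
```

In a real implementation the tracer would run the user's definition against placeholders built from the signature; here it just derives the parameter list from the argument specs.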

For outputs, I want to support single tensors, a list, and a String -> Operand map (shown above). The values of the list and map would be limited to tensors. The resulting Function objects will be type-safe with respect to their outputs (i.e. the call method returns a List, a Map, or an Operand). I'll also add a type-safe single-input, single-output version.

Variables and Captures

As described in #237, global variables are created in an attached eager initScope, and only during the first trace. This works since everything from the initScope is automatically captured when used. Also as described in #237, I would like a way to create variables once and automatically reuse them on further calls, but this requires unique IDs across the whole execution environment, which is prohibitive (or at least I can't think of a better way).

Operands from other environments that are accessible from the function definition can be captured by the closure. Only operands from the initScope will be captured automatically on use. Otherwise, you can use input.constCapture(x) to bring x into scope if it's from a compatible environment, or input.capture(() -> x) (the preferred way), which will reflect any updates to x.
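A minimal sketch of the difference between the two capture styles, with hypothetical names and plain `double` values standing in for operands. `constCapture` reads the value once, when it is captured, while `capture` keeps the supplier and re-reads it on every call, which is how a lambda capture can reflect later updates to `x`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical capture list for one function; doubles stand in for operands.
final class Captures {
  private final List<Supplier<Double>> captured = new ArrayList<>();

  // Like the proposed constCapture(x): the value is read once, now, and baked in.
  int constCapture(double value) {
    captured.add(() -> value);
    return captured.size() - 1;
  }

  // Like the proposed capture(() -> x): the supplier is kept and re-read on every
  // call, and its result is passed as an extra input.
  int capture(Supplier<Double> value) {
    captured.add(value);
    return captured.size() - 1;
  }

  // The extra arguments that would be fed to the function for one call.
  List<Double> argumentsForCall() {
    List<Double> args = new ArrayList<>();
    for (Supplier<Double> s : captured) {
      args.add(s.get());
    }
    return args;
  }
}
```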

Constant captures from eager sessions use tf.constant(x.asTensor()), while lambda captures and captures from graphs work by adding an input and providing the argument each time the function is called.

Operands from any eager environment can be captured, but if captures from a graph are used, they must all come from the same graph, and the function can only be called from that graph, since the captured value needs to be accessible when the function is called. This is almost always the case anyway.
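A hypothetical sketch of the single-graph rule for captures. `Env` and `CaptureScope` are illustrative names, not part of the TensorFlow Java API: eager captures are always accepted, the first graph capture pins the function to that graph, and any later graph capture or call must come from that same graph.

```java
// Hypothetical execution environment: eager, or a particular graph.
record Env(String id, boolean isGraph) {}

// Enforces: graph captures all come from one graph, and the function is only
// called from that graph. Eager captures are unrestricted.
final class CaptureScope {
  private Env captureGraph; // null until the first graph capture

  void recordCapture(Env source) {
    if (!source.isGraph()) {
      return; // eager values can always be captured
    }
    if (captureGraph == null) {
      captureGraph = source;
    } else if (!captureGraph.equals(source)) {
      throw new IllegalArgumentException(
          "Graph captures must all come from " + captureGraph.id() + ", got " + source.id());
    }
  }

  void checkCaller(Env caller) {
    if (captureGraph != null && !captureGraph.equals(caller)) {
      throw new IllegalStateException(
          "Function with graph captures can only be called from " + captureGraph.id());
    }
  }
}
```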

Note: this is largely an implementation detail, but functions can be inlined when called from another function (and possibly from a Graph, although that's a bit harder). I still need to work out exactly how that will work with captures and so on, but it should definitely be possible (Python does it).

Saving and Loading

Currently not supported. This is described a bit in #237, but essentially we would need to convert the initial values from an eager context to a graph, and the needed C APIs aren't exposed yet (if it's even possible). There are workarounds, like just saving the current value, that we could look into if this is really necessary.

Craigacp commented 3 years ago

For the function definition syntax, I think that Kotlin is creating a method reference? Could you use Java in the examples please (unless referring to how this would be exposed in Kotlin). Is that method reference then immediately executed (i.e. traced) by Functions.define?

If it is indeed creating a method reference, then I think we should define a Java functional interface which we can type more strongly which is accepted by the function creation system. It's going to be hard to express to Java developers that this code is only executed once to build the graph, especially if it allows arbitrary Java inputs (which I actually would like to avoid if at all possible because the behaviour will be hard to explain), so relying upon lambdas which conform to some named type (e.g. FunctionBuilder) will help signal that this builds the function rather than executes it. We could even rely on an explicit builder pattern though the appropriate typing signal might be harder.

Why do we want to allow retracing when Java inputs change? It'll be a pain to track, especially if the Java input objects are mutable, and might well lead to very odd behaviour. Why not make the user define a new function with the updated inputs? I think the retracing is likely to be a source of bugs in Python anyway, it is in JAX, and I'm not sure we want to encourage mysterious behaviour.

What's the return type from Function.define, or does it mutate some internal state? We should probably name them too.

rnett commented 3 years ago

For the function definition syntax, I think that Kotlin is creating a method reference? Could you use Java in the examples please (unless referring to how this would be exposed in Kotlin). Is that method reference then immediately executed (i.e. traced) by Functions.define?

Yeah, and I converted it. Tracing would be done at the first call, since it requires the arguments.

If it is indeed creating a method reference, then I think we should define a Java functional interface which we can type more strongly which is accepted by the function creation system.

A functional interface sounds good to me. Something like:

@FunctionalInterface
public interface MapFunctionDefinition {
  public Map<String, Operand<?>> define(Ops tf, Inputs inputs);
}

@FunctionalInterface
public interface ListFunctionDefinition {
  public List<Operand<?>> define(Ops tf, Inputs inputs);
}

@FunctionalInterface
public interface FunctionDefinition<T extends TType> {
  public Operand<T> define(Ops tf, Inputs inputs);
}

with nice javadocs.

Why do we want to allow retracing when Java inputs change?

The reason behind Java inputs is mostly to allow control flow. In python you can do something like:

def test(x, neg=True):
  if neg:
    return -x
  else:
    return x

To allow the same kind of thing in Java, we need a way to pass Java objects into the function, since we can't get the values from tensors. I don't love it, especially since you lose the type information, but I don't see a good way around it. Python converts the if statement to a tf.cond op, which we can't do, so to get the new output we have to retrace.

There's also a use case I'm looking at for the Kotlin version: serializing an input with operands into a template and an operand list, passing the operands and template in as inputs, and deserializing them in the function, which would allow any serializable type to be used as an input. It's a bit specific, but more generally, allowing Java inputs opens up a lot of possibilities for more complicated functions like this.

For mutability, I need to check what Python does, but what I think it does, and what I would do here, is retrace every time if there are any Java inputs, which makes mutability a non-issue and avoids having to cache the values. We could then add an inputs.immutableJavaInput<T>() that would check equality instead.
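A minimal sketch of the two Java-input policies proposed above, with hypothetical names and a plain `Function` standing in for tracing. A mutable Java input forces a retrace on every call, while the proposed `immutableJavaInput` variant keeps one trace per value, keyed by `equals`/`hashCode`. That keying is also why mutable inputs can't simply be cached: mutating a `HashMap` key after it was inserted breaks later lookups.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical: a function with one Java (non-tensor) input. `tracer` stands in
// for tracing the definition with a given Java value.
final class JavaInputFunction<J, R> {
  private final Function<J, R> tracer;
  private final boolean immutable; // true for the proposed immutableJavaInput
  private final Map<J, R> cache = new HashMap<>();
  int traceCount = 0;

  JavaInputFunction(Function<J, R> tracer, boolean immutable) {
    this.tracer = tracer;
    this.immutable = immutable;
  }

  R call(J javaValue) {
    if (!immutable) {
      // Mutable input: equality can't be trusted later, so always retrace.
      traceCount++;
      return tracer.apply(javaValue);
    }
    // Immutable input: reuse the trace for equal values.
    return cache.computeIfAbsent(javaValue, v -> {
      traceCount++;
      return tracer.apply(v);
    });
  }
}
```

With the `neg` example above, an immutable `Boolean` input would be traced at most twice, once for `true` and once for `false`.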

I'm not sure what you mean by "Why not make the user define a new function with the updated inputs?". That's essentially what's going on behind the scenes anyway. And since there's no good way to declare inputs other than through inputs, you'd have to do something like:

class ParamFunc{
  public ParamFunc(boolean neg){
    this.neg = neg;
  }
  private final boolean neg;

  // defines a brand new function for this particular value of neg
  public final Function func(){
    return Function.define((tf, inputs) -> {
      final var x = inputs.input("x", TFloat32.class);
      if (neg){
        return tf.math.neg(x);
      } else {
        return x;
      }
    });
  }
}

// use: a new ParamFunc (and so a new function) is created on every call
public static Operand<TFloat32> test(Ops tf, Operand<TFloat32> x, boolean neg){
  return new ParamFunc(neg).func().call(tf, x);
}

which is a bit of a mess. Plus, the bigger problem with re-defining a function like this is that they can't share variables, so if you declare a (global) variable inside the function it would be re-created each call (whereas if we do this via re-tracing they will be reused).

I could see adding some way to force retracing, as bug prevention, something like callWithRetrace.

What's the return type from Function.define, or does it mutate some internal state? We should probably name them too.

Function.define would return a Function, MapFunction, or ListFunction, depending on what the definition returns. It's callable like ConcreteFunction. There's no internal state.

rnett commented 3 years ago

An update, since I've been playing with generating function ops: there is room for a function class that supports captures (maybe; I'd need to test the functional ops), variables, and the new init scope, but not retracing. This is equivalent to Python's DeFun, and its main use is that, since it only ever has one ConcreteFunction, you can pass it to ops like If. It also lets users avoid retracing issues at the cost of some functionality (collection inputs won't work, and Java inputs only work if they are immutable, at which point they should be captures, so I'd forbid them). I'd prefer to keep ConcreteFunction separate, as a lower-level API and for cases where you don't want an eager init scope, but they could be combined.
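A hypothetical sketch of the no-retracing function class described above (named `DefinedFunction` here, as later in this thread). Since its signature is fixed when it is defined, this sketch traces it once, eagerly; that eagerness is an assumption, not something stated above. Calls with any other signature are rejected, and the single concrete function is always available, which is the property that would let it be passed to ops like If. Signatures are plain strings.

```java
import java.util.List;
import java.util.function.Supplier;

// Hypothetical no-retracing function: one fixed signature, one trace.
final class DefinedFunction<R> {
  private final List<String> signature; // stand-in for dtypes/shapes
  private final R concrete;             // stand-in for the single ConcreteFunction

  DefinedFunction(List<String> signature, Supplier<R> trace) {
    this.signature = List.copyOf(signature);
    // The signature is already known, so (in this sketch) trace immediately.
    this.concrete = trace.get();
  }

  // Always exactly one concrete function, e.g. for passing to an If op.
  R concreteFunction() {
    return concrete;
  }

  R call(List<String> argSignature) {
    if (!signature.equals(argSignature)) {
      throw new IllegalArgumentException(
          "Expected signature " + signature + " but got " + argSignature);
    }
    return concrete;
  }
}
```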

rnett commented 3 years ago

Two more things I want to add to this, clarifying the above a bit:

Function types

I'd want three types of functions (names TBD):

Inlining is the ability to inline into an outer function graph (and maybe a normal graph) and to execute in eager mode for debugging (with a global flag, like Python), which requires keeping the definition lambda around.

Polymorphism on input definers

I'm thinking of doing something like:

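// not sure the type inference will work properly like this, may have to add helper functions w/ concrete types.
// Also, as written these two overloads have the same erasure, define(FunctionDefiner), so javac rejects them; they need distinct names.
public static DefinedFunction define(FunctionDefiner<? super CapturesAndInputs> def)
public static Function define(FunctionDefiner<? super InputSpec> def)

// possibly more, with different output types.  Or add type param for output type
public interface FunctionDefiner<T extends InputSpec>{
   public Outputs define(Ops tf, T inputs);
}

abstract class InputSpec{
    protected final Ops tf;
    protected InputSpec(Ops tf){ this.tf = tf; }
}
abstract class CapturesOnly extends InputSpec{
    protected CapturesOnly(Ops tf){ super(tf); }
    public abstract <T extends TType> Operand<T> capture(Producer<Operand<T>> value);
    public abstract <T extends TType> Operand<T> captureConstant(Operand<T> value);
}
abstract class CapturesAndInputs extends CapturesOnly {
    protected CapturesAndInputs(Ops tf){ super(tf); }
    public abstract <T extends TType> Operand<T> input(String name, Class<T> type);
}
// extends CapturesAndInputs (not CapturesOnly) so that it actually has inputs, as its name says
abstract class CapturesInputsAndJavaValues extends CapturesAndInputs {
    protected CapturesInputsAndJavaValues(Ops tf){ super(tf); }
    public abstract <T> T javaInput(String name, Class<T> type);
    public abstract <T> T javaCapture(Producer<T> value);
}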

The exact hierarchy is still up in the air. I'm considering adding one for captures, inputs, and collection inputs, or maybe only having InputSpec, Captures, CapturesAndInputs, and Everything. These levels matter because InputSpec (i.e. no inputs, Producer-like) is what all ops that take functions support, Captures is what is supported by functions passed to ops, like If, that support passing inputs, CapturesAndInputs is what DefinedFunction supports, and Everything is everything Function supports.
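Java's `? super` wildcards provide the "a definer that needs fewer capabilities fits anywhere more are offered" behaviour this hierarchy relies on. The sketch below uses simplified, hypothetical versions of the classes above, with no TensorFlow types and a `String` result in place of `Outputs`: a definer written against `CapturesOnly` is accepted where a `CapturesAndInputs` will be supplied, but not the other way round.

```java
// Simplified, hypothetical capability hierarchy: each level adds what it can offer.
class InputSpec {}
class CapturesOnly extends InputSpec {}
class CapturesAndInputs extends CapturesOnly {}

// A definer that needs at least the capabilities of T.
interface FunctionDefiner<T extends InputSpec> {
  String define(T spec);
}

final class Definers {
  // DefinedFunction-style entry point: it will supply a CapturesAndInputs, so
  // any definer that needs CapturesAndInputs or less is accepted.
  static String defineWithInputs(FunctionDefiner<? super CapturesAndInputs> def) {
    return def.define(new CapturesAndInputs());
  }

  // If-branch-style entry point: it only supplies captures.
  static String defineBranch(FunctionDefiner<? super CapturesOnly> def) {
    return def.define(new CapturesOnly());
  }
}
```

Passing a `FunctionDefiner<CapturesAndInputs>` to `defineBranch` fails to compile, which is the guard the hierarchy is meant to give. The `ifOp` signature below uses the invariant `FunctionDefiner<CapturesOnly>`; a `? super CapturesOnly` bound, as in `defineBranch`, would additionally accept definers written against plain `InputSpec`.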

That way we can have an If op method that looks like:

public <T extends TType> Operand<T> ifOp(Operand<TBool> cond, FunctionDefiner<CapturesOnly> thenBranch, FunctionDefiner<CapturesOnly> elseBranch)

that uses DefinedFunction internally. Some magic needs to happen to find the captures and pass them through properly, so the signature might change slightly, but not by much. Most likely it will involve a new subclass of CapturesOnly with a shared capture pool, but that's an implementation detail that won't be exposed.