tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

how to implement DataType families? #115

Closed deansher closed 3 years ago

deansher commented 4 years ago

Currently in tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java, we use an uncomfortable pattern: DataType is omniscient about TType and dispatches on a string TType.NAME. For example:

  /** Returns true if this data type represents an integer type */
  public boolean isInteger() {
    switch (this.name()) {
      case TInt32.NAME:
      case TInt64.NAME:
      case TUint8.NAME:
        return true;
      default:
        return false;
    }
  }

The author of this pattern, @JimClarke5, mentioned in our Google Group that he regarded it as temporary:

My present code does a switch on DataType.name(), but IMO, this isn’t the most elegant way to do this.

@karllessard suggested a direction, although with some open questions:

Each data type in Java inherits from a "type family" as in here, which can be used to set bounds on a given data type when used as a generic parameter (e.g. Tensor<? extends TNumber> to only accept tensors that are numeric). But if that doesn't work in your case and you really want to check the data type family at runtime, then we need to add new methods, like dataType.isNumber(). I think ideally it should be in line with the same data type classes defined in the core library; the new methods could even be added to the C API, in this file.

Let's decide on a direction! This is moderately pervasive in our code, but also a pretty simple change, so I'd advocate we choose a direction soon and I'm tempted to volunteer to make the change.

deansher commented 4 years ago

I think most Java developers would expect to use instanceof on the marker interface. Should we simply embrace that, and extend our system of marker interfaces? (Perhaps some of these interfaces would go beyond being just markers?)

Although surely we want to align closely with the data type families defined in the core library, and that does motivate exposing them through the core library's C API, I see drawbacks to that approach:

The set of data types evolves slowly, so I see little drawback to maintaining the type families as hand-coded interface inheritance. Also, I can imagine times when a bit of poetic license would be beneficial to fit better into Java's conceptual structure.

JimClarke5 commented 4 years ago

The requirement is that some methods allow only a floating data type, some allow a number (floating + integer), some allow a boolean and number, while others allow any type. As long as we can make that distinction, we could use enum or instanceof.
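Jim's four cases (floating only, numbers, boolean or number, anything) map naturally onto a small hierarchy of marker interfaces. The sketch below is only an illustration: TType, TNumber, TFloating, TIntegral, TBool, TFloat32, TInt32 and Families are hypothetical, simplified stand-ins (the real TF Java TType is generic and the type classes are tensors). A generic bound gives the compile-time restriction, and Class.isAssignableFrom gives the equivalent runtime check without switching on a name.

```java
// Hypothetical, simplified type families; the real TF Java TType is generic.
interface TType {}
interface TNumber extends TType {}
interface TFloating extends TNumber {}
interface TIntegral extends TNumber {}

// Hypothetical stand-in type classes (in TF Java these would also be tensors).
final class TFloat32 implements TFloating {}
final class TInt32 implements TIntegral {}
final class TBool implements TType {}

final class Families {
  private Families() {}

  // Compile-time restriction: only floating types satisfy the bound.
  static <T extends TFloating> String onlyFloating(Class<T> type) {
    return type.getSimpleName();
  }

  // Runtime checks that replace a switch on DataType.name().
  static boolean isFloating(Class<? extends TType> type) {
    return TFloating.class.isAssignableFrom(type);
  }

  static boolean isNumber(Class<? extends TType> type) {
    return TNumber.class.isAssignableFrom(type);
  }

  static boolean isBoolOrNumber(Class<? extends TType> type) {
    return TBool.class.isAssignableFrom(type) || isNumber(type);
  }
}
```

With this shape, Families.onlyFloating(TInt32.class) is rejected by the compiler, whereas Families.isNumber(TInt32.class) simply returns true at runtime.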

Craigacp commented 4 years ago

I think a set of interfaces is fine. Either way when we move to a version of Java with sealed classes & type patterns in switches (https://openjdk.java.net/jeps/360 and https://github.com/openjdk/amber-docs/blob/master/site/design-notes/type-patterns-in-switch.md) then we can get benefits from having a closed hierarchy and hopefully cut out a bunch of unnecessary code.

karllessard commented 4 years ago

Hi Dean & all, thanks for starting a thread on this very important topic,

Properly manipulating data types in TensorFlow Java is indeed a non-trivial task. You will also find in this PR draft https://github.com/tensorflow/java/pull/92 multiple attempts to reduce the complexity of the solution currently in place. Let me dive a little deeper into the subject, with some history lessons (sorry for the long message, but I think it is important to understand every intrinsic detail to make the right decisions).

The first version of the API only supported runtime data type checks via Tensor.datatype() https://github.com/tensorflow/java/blob/1d35c17dcc85286a91f59a6ff0b94c48f1b8d4b1/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java#L221 and Output.datatype() https://github.com/tensorflow/java/blob/1d35c17dcc85286a91f59a6ff0b94c48f1b8d4b1/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java#L46, where a DataType was simply represented by an enum https://github.com/tensorflow/tensorflow/blob/46b6537110a00470dfc3093911f1a8c9eebdbe25/tensorflow/java/src/main/java/org/tensorflow/DataType.java#L24. Then, support was added for carrying the data type of each graph node and tensor as a generic parameter on their respective classes, so that some type validation can be done at compile time. While some collaborators have questioned the practicality of doing this, I personally think it is a good example of leveraging idiomatic features of Java compared to dynamically typed languages like Python.

It is important to note that at that time, it was decided to map common TensorFlow data types to standard Java types, familiar to all developers, falling back on a custom type class only when no such correspondence exists. You can still consult this original mapping https://github.com/tensorflow/tensorflow/blob/46b6537110a00470dfc3093911f1a8c9eebdbe25/tensorflow/java/src/main/java/org/tensorflow/DataType.java#L107 in the TF Java 1.x implementation. While this solution was elegant, it had some drawbacks that forced me to switch to a different design when adding support for NdArray:

That sort of explains why I've ended up adding this hierarchy of type classes and interfaces https://github.com/tensorflow/java/tree/master/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types to replace the former Java standard types as the parameter of our generic classes like Tensor, Output, Operand, etc. This way, when calling Tensor<T>.data() https://github.com/tensorflow/java/blob/1d35c17dcc85286a91f59a6ff0b94c48f1b8d4b1/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java#L275, which returns an instance of T, we have everything in hand to efficiently access the tensor memory (e.g. TFloat32 extends directly from FloatNdArray with all its primitive endpoints). This detail is very important to understand, as it justifies most of the decisions that were taken in the current design of data types in TF Java.

In an attempt to centralize everything related to a given data type in a single file, I've also replaced the original enum in DataType by static instances https://github.com/tensorflow/java/blob/1d35c17dcc85286a91f59a6ff0b94c48f1b8d4b1/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java#L41 of this class instead, instantiated in each type class. This also allows us to pass more complex parameters to the instantiation of the data type object, such as the tensor data mapper. Ideally, the factory for DataType would have been package-private, as only a finite set of instances should exist (sealed interfaces would enforce this as well). But since the type classes and DataType are not in the same package (and moving them would break a few things as well), I had to keep it public. Still, instead of comparing data types by name at runtime as Jim did temporarily, we could compare them by their identity:

if (dtype == TInt32.DTYPE || dtype == TFloat32.DTYPE || ...

or even better, we can pass the characteristics of a given data type to this factory https://github.com/tensorflow/java/blob/1d35c17dcc85286a91f59a6ff0b94c48f1b8d4b1/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java#L55 as a series of booleans, a BitSet or maybe even its type family, so that methods like DataType.isNumber() won't need to switch on a value at runtime.
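A minimal sketch of passing the characteristics to the factory, assuming a hypothetical Family enum and a simplified DataType whose create(...) signature is invented for illustration (the real factory takes other parameters, such as the byte size and the tensor instantiator). The family flags are fixed when each constant is created, so isNumber() and friends become set lookups instead of a switch on name(). The native codes 1, 3 and 10 are TensorFlow's DT_FLOAT, DT_INT32 and DT_BOOL values.

```java
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;

// Hypothetical family flags supplied to the factory at construction time.
enum Family { NUMBER, FLOATING, INTEGRAL, BOOLEAN }

final class DataType {
  private final String name;
  private final int nativeCode;
  private final Set<Family> families;

  private DataType(String name, int nativeCode, Set<Family> families) {
    this.name = name;
    this.nativeCode = nativeCode;
    this.families = Collections.unmodifiableSet(families);
  }

  // Hypothetical factory signature, for illustration only.
  static DataType create(String name, int nativeCode, Family first, Family... rest) {
    return new DataType(name, nativeCode, EnumSet.of(first, rest));
  }

  String name() { return name; }
  int nativeCode() { return nativeCode; }

  // No switch on name(): each query reads the flags given at creation.
  boolean isNumber() { return families.contains(Family.NUMBER); }
  boolean isFloating() { return families.contains(Family.FLOATING); }
  boolean isInteger() { return families.contains(Family.INTEGRAL); }
  boolean isBoolean() { return families.contains(Family.BOOLEAN); }
}

// Hypothetical constants; each exists once, so they can be compared by identity.
final class DTypes {
  static final DataType FLOAT32 = DataType.create("FLOAT", 1, Family.NUMBER, Family.FLOATING);
  static final DataType INT32 = DataType.create("INT32", 3, Family.NUMBER, Family.INTEGRAL);
  static final DataType BOOL = DataType.create("BOOL", 10, Family.BOOLEAN);
  private DTypes() {}
}
```

An EnumSet keeps the flags compact (it is backed by a bit vector) while staying more readable than a raw BitSet.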

I do agree, though, that mixing runtime and compile-time data type checks increases the cognitive load for our users, as Dean explained. If we could do everything at compile time (which is again a powerful idiomatic feature of Java that we should leverage), that would be perfect, but it is nearly impossible. Replacing raw switches by instanceof sounds right at first, but you need to be careful: it is not the DataType instances that extend the TType families but the type classes themselves. Instances of these classes are only accessible once you have mapped the memory of a tensor. So you won't be able to extract valuable information simply from Output.datatype(), which is mostly useful in graph building since we don't always have access to a tensor there.

deansher commented 4 years ago

My leaning from an initial study of this situation is that DataType is now trying to do too much. It has many roles:

  1. runtime representation of a tensor type
  2. runtime metadata for a tensor type
  3. Java proxy for the C++ (protobuf) DataType
  4. carries the compile-time tensor type, at least in Tensor's <U extends TType> U expect(DataType<U> dt)

I love the overall direction of PR draft #92. I think the PR's definition of TType carries its weight:

TType<T extends TType, U> extends Tensor<U>
Tensor<T> extends NdArray<T>

I would find this much easier to understand if we could express it this way, where S is the recursive/self type and E is the native element type:

Tensor<S extends Tensor, E>

Anyway, I'm good so far, but when we add DataType's type parameterization, my brain explodes:

DataType<T extends TType>

Looking back at the current roles of DataType, here's what I'd currently propose:

  1. runtime representation of a tensor type: Could we let the tensor's Class do this? That would correspond exactly to our compile-time tensor typing. It would empower idioms like instanceof and isAssignableFrom as our official runtime tensor type checks.
  2. runtime metadata for a tensor type: Could we define a Tensor.TypeMeta or DataType.Meta and provide methods for obtaining that from a Tensor's Class, from a Tensor instance, and from a DataType?
  3. Java proxy for the C++ (protobuf) DataType: I like the original Java DataType enum for this. But perhaps that's just part of our core API, not part of our framework API?
  4. carries the compile-time tensor type: Could we use the tensor's compile-time type for this? For example, we could encourage an unchecked cast instead of <U extends TType> U expect(DataType<U> dt).
karllessard commented 4 years ago

I would find this much easier to understand if we could express it this way, where S is the recursive/self type and E is the native element type:

Tensor<S extends Tensor, E>

The reason why I preferred to carry both the recursive type and the Java type in TType instead of Tensor is that I expect users won't often manipulate an instance of TType directly, since it is mostly used to bound our generic types, and therefore they won't have to carry these two parameters everywhere, as they would with Tensor.
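A minimal sketch of what carrying both parameters in TType buys, assuming simplified hypothetical Tensor, TType, TBool, TBoolImpl and Ops types (the real interfaces have many more methods). User code only ever names TBool, while generic helpers bounded by TType<T, U> can still hand back the concrete type T or read values as the element type U.

```java
// Hypothetical, simplified tensor parameterized only by its Java element type.
interface Tensor<E> {
  E getObject(long index);
}

// Hypothetical TType carrying the recursive self type T and the element type U.
interface TType<T extends TType<T, U>, U> extends Tensor<U> {
  // The self type lets inherited methods return the concrete subtype.
  @SuppressWarnings("unchecked")
  default T self() { return (T) this; }
}

// Users see only this name; the two type arguments are fixed once, here.
interface TBool extends TType<TBool, Boolean> {}

final class TBoolImpl implements TBool {
  private final boolean[] values;
  TBoolImpl(boolean... values) { this.values = values.clone(); }
  @Override public Boolean getObject(long index) { return values[(int) index]; }
}

final class Ops {
  private Ops() {}

  // The element type U gives generic code correctly typed values.
  static <T extends TType<T, U>, U> U first(T tensor) { return tensor.getObject(0); }

  // The self type T gives generic code back the caller's concrete type.
  static <T extends TType<T, U>, U> T same(T tensor) { return tensor.self(); }
}
```

A caller writes TBool b = Ops.same(input) and Boolean v = Ops.first(input) without ever spelling out TType<TBool, Boolean>.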

Anyway, I'm good so far, but when we add DataType's type parameterization, my brain explodes:

DataType<T extends TType>

This is not really required when you think about it, even in the current implementation. It is only used to safely type the returned value of DataType.map, which is for internal use only anyway, so we could simply cast it explicitly or implicitly instead.

runtime representation of a tensor type: Could we let the tensor's Class do this? That would correspond exactly to our compile-time tensor typing. It would empower idioms like instanceof and isAssignableFrom as our official runtime tensor type checks.

With PR draft #92, yes, that will work, and I agree it would be a great improvement. The reason for having a DataType instance for each type (e.g. TBool.DTYPE) was to initialize and carry the metadata associated with that type, which you cover in the next point:

runtime metadata for a tensor type: Could we define a Tensor.TypeMeta or DataType.Meta and provide methods for obtaining that from a Tensor's Class, from a Tensor instance, and from a DataType?

TF 1.x was doing something similar; I'm OK with going back to this scheme if it improves the whole user experience. I don't think we need a Meta subclass: DataType can still be used to carry this metadata (without any generic parameter, as explained above).

I would leave the responsibility of initializing this metadata to the TType classes, to keep all information about one data type encapsulated in one place. We still need to statically initialize each type class, as we do in DataTypes, since we might need to resolve that type class before a user explicitly refers to it in the code.

So we would now do something like:

public final class DataTypes {
  static {
    register(TBool.class, TBool.DTYPE);
    ...
  }
}

public interface TBool extends BooleanTensor, TType<TBool, Boolean> {
  DataType DTYPE = DataType.create("BOOL", 10, 1, TBoolImpl::new);
}
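To make the static-registration idea concrete, here is a minimal, self-contained sketch. DataType, TBool, TInt32 and DataTypes are simplified hypothetical stand-ins (the native codes 3 and 10 are TensorFlow's DT_INT32 and DT_BOOL values). Because the static initializer of DataTypes reads each DTYPE constant, any lookup through DataTypes forces every registered type class to initialize first, in both directions.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified DataType: just a name and the native TF_DataType code.
final class DataType {
  final String name;
  final int nativeCode;
  DataType(String name, int nativeCode) { this.name = name; this.nativeCode = nativeCode; }
}

// Hypothetical type interfaces, each owning its DTYPE constant.
interface TBool { DataType DTYPE = new DataType("BOOL", 10); }
interface TInt32 { DataType DTYPE = new DataType("INT32", 3); }

// Central registry: initializing DataTypes initializes every listed type class.
final class DataTypes {
  private static final Map<Class<?>, DataType> BY_CLASS = new HashMap<>();
  private static final Map<Integer, Class<?>> BY_CODE = new HashMap<>();

  static {
    register(TBool.class, TBool.DTYPE);
    register(TInt32.class, TInt32.DTYPE);
  }

  private DataTypes() {}

  private static void register(Class<?> type, DataType dtype) {
    BY_CLASS.put(type, dtype);
    BY_CODE.put(dtype.nativeCode, type);
  }

  // Class -> metadata (the lookup needed when operations take a Class).
  static DataType of(Class<?> type) { return BY_CLASS.get(type); }

  // Native code -> class (the reverse lookup needed when a tensor comes back from native code).
  static Class<?> fromNativeCode(int code) { return BY_CODE.get(code); }
}
```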

Java proxy for the C++ (protobuf) DataType: I like the original Java DataType enum for this. But perhaps that's just part of our core API, not part of our framework API?

Right now we build the protobuf DataType from the numeric value of a given type, like in here. That will still be possible then from TBool.META.nativeCode() (which is currently package-private but could be exposed if needed). I don't know how many use cases there are where we need to do this conversion, though.

carries the compile-time tensor type: Could we use the tensor's compile-time type for this? For example, we could encourage an unchecked cast instead of <U extends TType> U expect(DataType<U> dt).

I agree; again, with PR #92 that should work well.

deansher commented 4 years ago

Ah, ok, I do see benefit in the TType layer as a carrier for the recursive type, allowing a simpler Tensor<E> (where E is the Java element type) as the underlying primitive. On the other hand, TType will then show up in many public API definitions, so although our user may not code to it, they will regularly look at it. I'll think through some examples to see if I'd advocate either way.

Putting that aside, as an alternative to this:

public final class DataTypes {
  static {
    register(TBool.class, TBool.DTYPE);
    ...
  }
}

public interface TBool extends BooleanTensor, TType<TBool, Boolean> {
  DataType DTYPE = DataType.create("BOOL", 10, 1, TBoolImpl::new);
}

we could do this:

public interface TBool extends BooleanTensor, TType<TBool, Boolean> {
  DataType DTYPE = DataType.create("BOOL", 10, 1, /* ==> */ TBool.class, TBoolImpl::new);
}

public final class DataType {
  . . .
  private <E, T extends TType<T, E>> DataType(String name, int nativeCode, int byteSize,
                                              Class<T> tensorClass,
                                              TensorInstantiator<E> tensorInstantiator) {
    . . .
    register(tensorClass, this);
  }

  /**
   * Returns the <code>TType</code> subtype corresponding to this data type.
   */
  public Class<? extends TType<?, ?>> tensorClass() {
    return tensorClass;
  }

  /**
   * Indicates whether this <code>DataType</code>'s tensor class is a subtype of <code>ttype</code>.
   */
  public boolean hasType(Class<? extends TType<?, ?>> ttype) {
    return ttype.isAssignableFrom(tensorClass);
  }
}

(I'm not sure yet how I'd advocate using generics in the above. I like the simplicity of the bare DataType, given the complexity latent in DataType<T extends TType>. But then we do have awkward wildcarding in cases like tensorClass(). Yet, I suspect that works out fine in actual use. I think it deserves exploration.)

The hasType name above is odd in the context of DataType, but I propose it because the same name seems pleasing in the context of more API-central classes like Output:

public final class Output<T extends TType<?, ?>> implements Operand<T> {
  . . .
  /**
   * Indicates whether this <code>Output</code>'s tensor class is a subtype of <code>ttype</code>.
   */
  public boolean hasType(Class<? extends TType<?, ?>> ttype) {
    return ttype.isAssignableFrom(dataType().tensorClass());
  }
}

So instead of code like this:

    boolean convertToFloat32 =
        logits.asOutput().dataType() == TFloat16.DTYPE
            || logits.asOutput().dataType() == TBfloat16.DTYPE;

We'd end up with code like this:

    boolean convertToFloat32 =
        logits.asOutput().hasType(TFloat16.class)
            || logits.asOutput().hasType(TBfloat16.class);

I see this as a benefit because of how it relates to compile-time typing:

  public Output<TFloat32> foo(Output<TFloat32> x) {
    . . .
  }

  public Output<TFloat32> bar(Output<?> x) {
    if (x.hasType(TFloat32.class)) {
        return foo((Output<TFloat32>) x);
    } else {
      . . .
    }
  }
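A runnable approximation of the hasType idea, assuming a hypothetical, non-generic TType hierarchy and a simplified Output that stores its tensor Class directly (the proposal above reaches it through dataType().tensorClass()). It shows the runtime check followed by the unchecked cast, as in the bar() example above.

```java
// Hypothetical, simplified type markers; not the real TF Java hierarchy.
interface TType {}
interface TNumber extends TType {}
interface TFloating extends TNumber {}
final class TFloat32 implements TFloating {}
final class TFloat16 implements TFloating {}
final class TInt32 implements TNumber {}

// Hypothetical Output that remembers the runtime class of its generic parameter.
final class Output<T extends TType> {
  private final Class<T> tensorClass;
  Output(Class<T> tensorClass) { this.tensorClass = tensorClass; }

  // Like instanceof, but applied to the generic parameter rather than to this object.
  boolean hasType(Class<? extends TType> ttype) {
    return ttype.isAssignableFrom(tensorClass);
  }
}

final class Example {
  private Example() {}

  static String foo(Output<TFloat32> x) { return "float32 path"; }

  // Runtime check first, then an unchecked cast that the check makes safe.
  @SuppressWarnings("unchecked")
  static String bar(Output<?> x) {
    if (x.hasType(TFloat32.class)) {
      return foo((Output<TFloat32>) x);
    }
    return x.hasType(TNumber.class) ? "other numeric path" : "non-numeric path";
  }
}
```

Because hasType accepts any supertype, the same method answers both "is this exactly TFloat32?" and "is this in the TFloating family?".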
karllessard commented 4 years ago

we could do this:

public interface TBool extends BooleanTensor, TType<TBool, Boolean> {
  DataType DTYPE = DataType.create("BOOL", 10, 1, /* ==> */ TBool.class, TBoolImpl::new);
}

Ideally we could do this; the problem, though, is that we won't be able to do reverse lookups from a native type code to a concrete Java class until the user explicitly initializes that class.

For example, take a fictional function that returns true or false depending on whether its input tensor has a zero value:

try (SavedModelBundle model = SavedModelBundle.load(...)) {
    Tensor<?> t = model.function("isZero").call(TInt32.scalar(0));
}

Under the hood, to instantiate the tensor t (and this is especially true in #92), we need to convert the numeric native code of the output tensor to TBool.class somehow. But if we register this mapping only in TBool itself, then we need to make sure that the static initialization of this class happens before we need to do this conversion.

This is why we now have this static helper class, DataTypes, taking care of setting up this mapping as well as searching into it.
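The initialization-order problem Karl describes can be reproduced in a few lines. Registry and TBool are hypothetical stand-ins, and a String plays the role of the DataType instance. Per the Java Language Specification, a class literal such as TBool.class does not initialize TBool, but reading its non-constant DTYPE field does, so a self-registering type is invisible to reverse lookups until something touches that field.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical registry that type classes populate from their own static initializers.
final class Registry {
  private static final Map<Integer, Class<?>> BY_CODE = new HashMap<>();
  private Registry() {}

  // Called while a type class initializes; returns a stand-in for a DataType.
  static String register(String name, int nativeCode, Class<?> type) {
    BY_CODE.put(nativeCode, type);
    return name;
  }

  static Class<?> lookup(int nativeCode) {
    return BY_CODE.get(nativeCode);
  }
}

// Self-registering type: DTYPE is not a compile-time constant, so the
// registration runs only when TBool itself is initialized.
interface TBool {
  String DTYPE = Registry.register("BOOL", 10, TBool.class);
}
```

Until some code reads TBool.DTYPE, Registry.lookup(10) returns null; that is exactly why a central class like DataTypes has to touch every type class up front.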

The hasType name above is odd in the context of DataType, but I propose it because the same name seems pleasing in the context of more API-central classes like Output

I like hasType(): it acts like instanceof, but on the generic parameter of the class exposing it. And again, in #92 it would be possible to simply use instanceof (or is, as in Kotlin) if the object is a Tensor.

@deansher , it looks like reviving #92 is a good next step, what do you think about starting your experimentations based on that branch?

deansher commented 4 years ago

Ideally we could do this, the problem though is that we won't be able to do reverse lookups from a native type code to a concrete Java class until the user initializes explicitly that class.

Yikes -- good point!

@deansher , it looks like reviving #92 is a good next step, what do you think about starting your experimentations based on that branch?

I agree wrt the principles, but I don't have enough context to judge wrt version/code management. @karllessard , is your intuition that

  • We could start with your current #92 branch, improve it and get tests running, and then successfully merge?
  • Or should we do further exploratory work based on the current #92 branch, expecting to start back at the then-current master when we are ready to start on the final version of these changes?
  • Or should we hypothesize a migration strategy now and immediately start back at master?
karllessard commented 4 years ago

We could start with your current #92 branch, improve it and get tests running, and then successfully merge?

That is what I had in mind, yes. Although this PR draft is getting "old", it shouldn't be very hard to rebase it on master when we are ready to merge it, i.e. once we are satisfied with the changes.

I'm aware that the PR might get pretty large by then, but this kind of change requires us to work on a complete solution to ensure that we are on the right track. We might still be able to split the PR into smaller parts later if needed.

deansher commented 4 years ago

I'll certainly go with your intuition on that! I'll try some small PRs against #92's branch. (Or else coordinate further if I can't PR in the main repo against a fork's PR branch.)

karllessard commented 4 years ago

So I was playing around with data types yesterday to take a break from the release work, and I ended up with a series of new observations concerning the removal of the generic parameter from the DataType class and the usage of T*.class instead of T*.DTYPE (where T* is TFloat32, TInt32, etc.)

In general, the transition went smoothly; only at Tensor allocation did I need to add an explicit cast to T*, because this type cannot be inferred from DataType. Where I stopped, though, was when I realized that for all operations that took a DataType as input, I would now need to do a table lookup to get it from T*.class.

For example, tf.empty(shape, TInt32.DTYPE) will now become tf.empty(shape, TInt32.class), but TInt32.class does not carry enough information to build that operation, and I need to find its DataType in a Map<Class<?>, DataType> (which I've added to DataTypes). The impact on performance is probably negligible, especially in graph mode, but it made me realize that DataType was put in place to substitute for Class, as it can carry more information about the type and avoid this additional lookup.

In other words, TInt32.DTYPE is to DataType<TInt32> as TInt32.class is to Class<TInt32>. This symmetry can explain why it might be acceptable to carry a generic parameter in DataType as we do now.

From this observation, I was wondering whether there is a way to add additional information to a Class instead of having our own class for doing it and, of course, there is: annotations. And basically, doing TInt32.class.getDeclaredAnnotation(DataType.class) ends up doing the same hash map lookup I'm doing now.

What is also interesting with annotations is that, with our annotation processor, we can scan all classes annotated with DataType and build the reverse lookup table from a native code to a type class at compile time, avoiding the manual type registration that we are currently doing.
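A minimal sketch of the annotation idea, covering only the runtime half (the annotation processor that would generate the reverse lookup table is not shown). The annotation is named DataType here following Karl's example; its members and the TInt32, TBool and TypeMeta types are hypothetical. The annotation needs RUNTIME retention, otherwise getDeclaredAnnotation cannot see it.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Hypothetical annotation carrying data type metadata on the type class itself.
@Retention(RetentionPolicy.RUNTIME) // must be RUNTIME to be visible via reflection
@Target(ElementType.TYPE)
@interface DataType {
  String name();
  int nativeCode();
  int byteSize();
}

// Hypothetical type interfaces; native codes are TensorFlow's DT_INT32 and DT_BOOL.
@DataType(name = "INT32", nativeCode = 3, byteSize = 4)
interface TInt32 {}

@DataType(name = "BOOL", nativeCode = 10, byteSize = 1)
interface TBool {}

final class TypeMeta {
  private TypeMeta() {}

  // Reads the metadata straight from the Class object; no manual registration needed.
  static DataType of(Class<?> type) {
    DataType meta = type.getDeclaredAnnotation(DataType.class);
    if (meta == null) {
      throw new IllegalArgumentException(type.getName() + " is not annotated with @DataType");
    }
    return meta;
  }
}
```

With this, tf.empty(shape, TInt32.class) could obtain the native code via TypeMeta.of(TInt32.class).nativeCode() instead of consulting a hand-maintained map.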

I'm not done brainstorming if the usage of annotations in this context makes total sense or not but I wanted to share my thoughts now so I can increase the number of neurons in my network :)

deansher commented 4 years ago

Your framing of this, Karl, aligns exactly with my mental model of it:

  • DataType<T> is a runtime representation of a tensor type.
  • Class<T> is Java's mainstream runtime representation of a type, so can we use that?
  • Yes we can, at the cost of a side-table lookup for metadata associated with Class<T>
  • Wouldn't it be great if there were a Java-mainstream way to add custom metadata to a Class<T>?
  • Yes, of course there is -- that problem comes up all over the place. It's annotations.

That said, I don't have experience providing a heavily-used package that defines its own annotations. So if there are gotchas, I don't know them.

deansher commented 3 years ago

Resolved by #174