tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Issues when translating from TType to TNumber and parameterized types within methods. #119

Open JimClarke5 opened 3 years ago

JimClarke5 commented 3 years ago

Many of the Java TF Ops use parameterized types for either TType, TNumber or both. Sometimes an Op uses <T extends TType> and sometimes another Op is using <T extends TNumber>. When writing a method that uses two different Ops that declare <T> differently, the compiler complains that T cannot be converted to the other type. It is interesting that TNumber is a subclass of TType. I have searched "Professor Google", but have not found an answer to this kind of problem.

TType to TNumber conversion is very common, especially if you are creating a base class with a common method signature across many similar objects. Sometimes the subclass calls for a TType, sometimes a TNumber. The real problem arises when you have a common method such as public <T extends TType> Operand<T> call(Operand<T> input).

As a workaround, let's say that you cast a TType to a TNumber (where <U extends TNumber>), as in:

@SuppressWarnings("unchecked")
Operand<U> uInput = (Operand<U>)input;

Now when you call something like tf.math.greater(uInput, otherValue);, the compiler complains: no instance(s) of type variable(s) exist so that T conforms to TNumber. That is because tf.math.greater uses <T extends TNumber> while other ops, like tf.nn.relu, define <T extends TType>.

Another way around this is to force erasure as in (Operand)value.
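To make the complaint concrete, here is a minimal sketch of the situation described above. The types TType, TNumber, TBool, TFloat32 and Operand in it are hypothetical stand-ins that copy only the bound structure mentioned in this thread; they are not the real TensorFlow Java classes, and relu/greater are placeholders that just build a string.

```java
// Stand-in types (hypothetical): they copy only the bound structure of
// TType / TNumber / Operand from the thread, not the real TensorFlow Java API.
interface TType {}
interface TNumber extends TType {}
final class TBool implements TType {}
final class TFloat32 implements TNumber {}

final class Operand<T extends TType> {
    final String name;
    Operand(String name) { this.name = name; }
}

final class BoundsDemo {
    // Bounded like tf.nn.relu in the thread: any TType is accepted.
    static <T extends TType> Operand<T> relu(Operand<T> x) {
        return new Operand<>("relu(" + x.name + ")");
    }

    // Bounded like tf.math.greater in the thread: only TNumber is accepted.
    static <T extends TNumber> Operand<TBool> greater(Operand<T> x, Operand<T> y) {
        return new Operand<>("greater(" + x.name + ", " + y.name + ")");
    }

    // The enclosing method only knows that T is some TType.
    static <T extends TType, U extends TNumber> Operand<TBool> call(
            Operand<T> input, Operand<U> threshold) {
        // greater(input, threshold);  // rejected: T is not known to be a TNumber
        @SuppressWarnings("unchecked")
        Operand<U> uInput = (Operand<U>) input;  // unchecked: nothing verifies this at runtime
        return greater(uInput, threshold);
    }

    public static void main(String[] args) {
        Operand<TFloat32> x = new Operand<>("x");
        Operand<TFloat32> t = new Operand<>("t");
        System.out.println(call(x, t).name);  // prints greater(x, t)
    }
}
```

Because the cast is unchecked, this sketch would accept a TBool operand just as happily, which is the kind of mismatch between compile-time and runtime checks the rest of the thread is concerned with.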

At a minimum, it would be nice if there were a convention like <T extends TType> and <U extends TNumber> used consistently, but this may not solve all of these kinds of issues, as I have seen <U extends TType, T extends TType> and <V extends TType, T extends TType, U extends TType>.

The main issue that contributes to this problem is that the Ops require a mixture of types, so a higher level user is artificially juggling the situation by casting like above, or by forcing an erasure of the type. IMO this situation is going to be confusing to the API user. I still haven't figured out a clean way to get around the issue when two method signatures use the same generic parameter in different ways.

Perhaps there is a better way. My gut feel is this is going to become a larger headache down the line.

The specific example where I am running into this problem at this time is:

@Override
public Operand<T> call(Operand<T> input) {
    @SuppressWarnings("unchecked")
    Operand<U> uInput = (Operand<U>) input;
    ....
    Operand<U> greater = tf.dtypes.cast(
            tf.math.greater(uInput,
                    tf.dtypes.cast(tf.constant(threshold), input.asTensor().dataType())),
            input.asTensor().dataType());
    uInput = tf.math.mul(uInput, greater);
    input = (Operand<T>) uInput;
    ...

deansher commented 3 years ago

I'd be amazed and disappointed if the Java compiler conflated type variables of the same name across different scopes, such as if "tf.math.greater uses <T extends TNumber> while other ops, like tf.nn.relu defines <T extends TType>", and if the problem is addressed by changing T to U in one of the scopes. Can you create a branch that isolates this problem?
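As a quick illustration of the scoping point, here is a small sketch using hypothetical stand-in types (not the real TensorFlow Java classes). It assumes two placeholder ops, absT and absU, that differ only in the name of their type variable; what decides whether a call compiles is the caller's bound, not the letter.

```java
// Stand-in types (hypothetical), mirroring only the bounds discussed in the thread.
interface TType {}
interface TNumber extends TType {}
final class TInt32 implements TNumber {}

final class Operand<T extends TType> {
    final String name;
    Operand(String name) { this.name = name; }
}

final class ScopeDemo {
    // Two placeholder ops with identical bounds but different variable names.
    static <T extends TNumber> Operand<T> absT(Operand<T> x) {
        return new Operand<>("absT(" + x.name + ")");
    }

    static <U extends TNumber> Operand<U> absU(Operand<U> x) {
        return new Operand<>("absU(" + x.name + ")");
    }

    // This T is a separate variable from the T declared by absT.
    // Both calls compile because this method's bound is TNumber.
    static <T extends TNumber> Operand<T> narrowCaller(Operand<T> x) {
        return absU(absT(x));
    }

    // With the wider bound, both calls are rejected, whatever the op's variable is named.
    static <T extends TType> Operand<T> wideCaller(Operand<T> x) {
        // return absT(x);  // does not compile: T may not be a TNumber
        // return absU(x);  // does not compile, for the same reason
        return x;
    }

    public static void main(String[] args) {
        Operand<TInt32> x = new Operand<>("x");
        System.out.println(narrowCaller(x).name);  // prints absU(absT(x))
    }
}
```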

I expect the main solution here is to be ruthlessly consistent in the specificity of our type constraints. I think we need to divide the parameters of our ops and framework methods into a few broad type categories, where both the compilation and the runtime support arbitrary mix and match within a category. To the extent at all possible, we need the compiler to complain if and only if an expression would fail at runtime.

It's a major ergonomics loss when a compiler complaint can be resolved by a simple language-level cast like (Operand<TNumber>) x. This means our typing is inconsistent with our runtime semantics. But it's a major ergonomics win when the compiler complains about a situation that indeed requires runtime changes.

It would also be a major ergonomics loss if we had lots of categories of parameters with different runtime and compile-time behavior. Perhaps we could limit the number of distinct, commonly used categories to something like 3. E.g. TFloating, TNumber, and TType? Then, I think we'd want to be pretty aggressive (especially at the framework level) in supporting the broadest category feasible for a particular parameter.

karllessard commented 3 years ago

Right now, the generated wrappers parametrize an operation to accept a TType or TNumber operand based on the information defined in the kernel definition of that op, for example here. If all the types listed for a given attribute are members of this group, then the attribute is bound to TNumber in Java. In the same file you'll find other groups that classify types according to some logic. The classification in Java (what we call a "family") should reflect this as well, to make sure that compile-time and runtime validations are in sync.

Now note that I have already seen cases in the past where some types were wrong in the kernel definition (I can't remember whether unsupported types were listed or whether it accepted all possible values when it should have been more restrictive). But basically, that is our source of truth.

In your example @JimClarke5, if the method call unconditionally invokes tf.math.greater, then the signature of that method should accept only TNumber operands as well, and that would resolve the compiler complaints. The problem is that if you have some conditional logic, you might need to cast explicitly, like this:

Operand<T> input = ...;
...
if (input.asOutput().dataType() == TFloat32.DTYPE) {
    tf.math.greater((Operand<TFloat32>)input, tf.constant(10.0f));
}

which is not bad in this case, but if your condition applies to more than one type, you need to go more generic:

if (input.asOutput().dataType() == TFloat32.DTYPE || input.asOutput().dataType() == TInt32.DTYPE) {
    tf.math.greater((Operand<? extends TNumber>)input, tf.constant(10.0f));
}

I think this complexity could be partially resolved by the work in #92, but it might not cover all cases yet, as it focuses more on tensors than on operands of any type.
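For the unconditional case, narrowing the method's own bound to TNumber can be sketched as follows. Everything here is a hypothetical stand-in: the Operand class carries a Class token instead of a real data type, and greater, cast and mul are placeholders that only build a string. The point is that once the enclosing method declares <T extends TNumber>, no unchecked cast is needed.

```java
// Stand-in types (hypothetical), not the real TensorFlow Java API.
interface TType {}
interface TNumber extends TType {}
final class TBool implements TType {}
final class TFloat32 implements TNumber {}

final class Operand<T extends TType> {
    final Class<T> type;  // stand-in for the operand's data type
    final String expr;
    Operand(Class<T> type, String expr) { this.type = type; this.expr = expr; }
}

final class NarrowDemo {
    // Placeholder for a comparison op that requires numeric operands.
    static <T extends TNumber> Operand<TBool> greater(Operand<T> x, Operand<T> y) {
        return new Operand<>(TBool.class, "greater(" + x.expr + ", " + y.expr + ")");
    }

    // Placeholder for a cast op: accepts any operand, returns the requested type.
    static <U extends TType> Operand<U> cast(Operand<? extends TType> x, Class<U> type) {
        return new Operand<>(type, "cast(" + x.expr + ")");
    }

    // Placeholder for a numeric multiply.
    static <T extends TNumber> Operand<T> mul(Operand<T> x, Operand<T> y) {
        return new Operand<>(x.type, "mul(" + x.expr + ", " + y.expr + ")");
    }

    // The bound is TNumber, matching the narrowest op called unconditionally.
    static <T extends TNumber> Operand<T> call(Operand<T> input, Operand<T> threshold) {
        Operand<T> mask = cast(greater(input, threshold), input.type);
        return mul(input, mask);
    }

    public static void main(String[] args) {
        Operand<TFloat32> x = new Operand<>(TFloat32.class, "x");
        Operand<TFloat32> t = new Operand<>(TFloat32.class, "t");
        System.out.println(call(x, t).expr);  // prints mul(x, cast(greater(x, t)))
    }
}
```

The trade-off is that a caller holding a TBool operand can no longer pass it in directly, so the compiler now rejects the inputs that the numeric ops would reject at runtime.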

JimClarke5 commented 3 years ago

I actually resolved this one by forcing everything to TNumber in the upcoming PR for losses.

The issue mainly comes from TBool, which is usually cast by TF to a TNumber like TInt32 later in the method. However, sometimes the input might be a TNumber but still has to be passed as a TType in the method signature. This gets tricky when calling another method with <T extends TNumber>.

karllessard commented 3 years ago

I actually resolved this one by forcing everything to TNumber in the upcoming PR for losses.

I think this is the ideal scenario. It is more of a headache for us, framework developers, to take care of this, but that is how we can end up with compile-time type checks that catch all possible runtime errors up front.

This gets tricky when calling another method with <T extends TNumber>.

Can you share an example? Normally if you call a method that accepts and returns T extends TType, and you pass a T extends TNumber operand, you should get back a T extends TNumber as well.
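A minimal sketch of that expectation, again with hypothetical stand-in types rather than the real TensorFlow Java classes: relu is a placeholder bounded by TType, abs is a placeholder bounded by TNumber, and the caller's T is bounded by TNumber. The result of the wider method is still typed Operand<T>, so it can be passed on without a cast.

```java
// Stand-in types (hypothetical), mirroring only the bounds discussed in the thread.
interface TType {}
interface TNumber extends TType {}
final class TFloat32 implements TNumber {}

final class Operand<T extends TType> {
    final String name;
    Operand(String name) { this.name = name; }
}

final class PreserveDemo {
    // Accepts and returns any TType.
    static <T extends TType> Operand<T> relu(Operand<T> x) {
        return new Operand<>("relu(" + x.name + ")");
    }

    // Accepts and returns only TNumber.
    static <T extends TNumber> Operand<T> abs(Operand<T> x) {
        return new Operand<>("abs(" + x.name + ")");
    }

    static <T extends TNumber> Operand<T> reluThenAbs(Operand<T> x) {
        // relu's own variable is inferred as this method's T, so the result
        // keeps the TNumber bound even though relu only asks for TType.
        Operand<T> r = relu(x);
        return abs(r);
    }

    public static void main(String[] args) {
        Operand<TFloat32> x = new Operand<>("x");
        System.out.println(reluThenAbs(x).name);  // prints abs(relu(x))
    }
}
```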

JimClarke5 commented 3 years ago

The signature for tf.math.abs() has <T extends TNumber>.

If I do this with <T extends TType> to support TBool:

public static <T extends TType> Operand<T> handle(Ops tf, Operand<T> input) {

        DataType<T> dataType = input.asOutput().dataType();

        Operand<? extends TNumber> tInput;
        if(dataType.isBoolean()) {
            tInput = tf.dtypes.cast(input, TInt32.DTYPE);
        }else {
            @SuppressWarnings("unchecked")
            tInput = (Operand<? extends TNumber>)input;
        }

        return tf.math.abs(tInput);
    }

This produces errors on :

If I try this, which is what I would like to do:

 public static <T extends TType> Operand<T> handle(Ops tf, Operand<T> input) {

        DataType<T> dataType = input.asOutput().dataType();

        if(dataType.isBoolean()) {
            input = tf.dtypes.cast(input, TInt32.DTYPE);
        }

        return tf.math.abs(input);
    }

Errors with:

deansher commented 3 years ago

Here's my understanding of this situation. I haven't tried running my handle2, but IntelliJ is happy with it syntactically.

  // T is an unknown subclass of TType. It could be a TString, a TBool, or a TNumber.
  public static <T extends TType> Operand<T> handle(Ops tf, Operand<T> input) {

    DataType<T> dataType = input.asOutput().dataType();

    if(dataType.isBoolean()) {
      // Since input is an Operand<T> and T may not be TInt32, this doesn't compile.
      input = tf.dtypes.cast(input, TInt32.DTYPE);
    }

    // Since abs requires a TNumber and input could be something else, this doesn't compile.
    return tf.math.abs(input);
  }

  // That didn't compile. Let's see what will:

  @SuppressWarnings("unchecked")
  public static <T extends TType> Operand<T> handle2(Ops tf, Operand<T> input) {

    DataType<T> inputType = input.asOutput().dataType();

    final Operand<? extends TNumber> x =
        inputType.isBoolean()
            ? tf.dtypes.cast(input, TInt32.DTYPE)
            : inputType.isString()
                ? tf.strings.toNumber((Operand<TString>) input)
                : (Operand<? extends TNumber>) input;

    // At this point, we know x is a TNumber, but we don't know its exact type at compile
    // time. We also don't know the exact type of T at compile time, so it's not trivial
    // to sort things out on the way back.

    // Not all hope is lost, because the runtime value of input gave us the desired
    // runtime DataType.

    final Operand<? extends TNumber> a = tf.math.abs(x);

    DataType<? extends TNumber> outputType = a.asOutput().dataType();
    if (outputType == inputType) {
      return (Operand<T>) a;
    } else if (inputType.isString()) {
      return (Operand<T>) tf.dtypes.asString(a);
    } else {
      return tf.dtypes.cast(a, inputType);
    }
  }

JimClarke5 commented 3 years ago

Your example will result in unchecked cast warnings.

deansher commented 3 years ago

It does depend on the @SuppressWarnings("unchecked") at the top of the method. Or is that not the thrust of your comment?

karllessard commented 3 years ago

If we take the #139 route, many things around type handling will be different. But for this particular discussion, it looks like everything works as expected so far, so if you agree @JimClarke5, I suggest that we close it and open a new one if problematic cases arise again?