eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

[WIP] Make Output parameterized with DataType. #14

Closed: sbrunk closed this 6 years ago

sbrunk commented 7 years ago

As discussed in #4, here's a draft implementation of a parameterized Output[+T <: DataType].

I have only parameterized OutputLike and its implementations for now. Tensor and Op are unchanged.

Tests are completely broken right now due to the changes. I've focused on getting main to compile first. If the design makes sense, I'll work on correct runtime behavior next.

eaplatanios commented 7 years ago

@sbrunk Thanks for the work you did on this! It seems it was quite cumbersome. I can see some issues with this approach at this point. How about we start with the Tensor class first, similar to what is proposed here? I'm suggesting this because we can settle on a convenient design for tensors without having to touch many parts of the library for now. Then, once we are confident about it, we can move on to Outputs. This would probably allow for a "cleaner" transition.

Also, "fixing" the Tensor class first would allow us to clean up the types package and the tensors package without touching any other part of the library (other than some small changes to the Session API).

sbrunk commented 7 years ago

Well, fortunately the changes were mostly mechanical. :)

Totally agree on starting with a typed Tensor. I actually tried that first but gave up quickly due to my limited knowledge of the current design (and the reasoning behind it).

I'm on mobile right now, so I can't really look into the proposed Java approach, but at first glance they seem to parameterize their Tensor and Output classes with the data type as well, right?

eaplatanios commented 7 years ago

Regarding the current design of Tensor, I don't really like it. I was sort of forced into it because I couldn't otherwise support some simple features or make interoperability with Scala primitives smoother. We can work on that by first adding support for typed tensors and looking at what pops up.

And yes, their Java approach is similar to yours. It would be interesting to follow their discussion to see what issues they had. Given the smaller feature set of the official Java API, it might be safe to assume we'll have more issues.

sbrunk commented 7 years ago

Yes, that might be worth a look. But I'm also hoping that Scala's type system allows us to avoid some of these issues.

sbrunk commented 7 years ago

Although we agreed to focus on Tensor first, I want to document the list of issues/questions I encountered making Output parameterized here for later:

The question that came up most often is how to restrict op input/output data types at compile time, e.g. if we want something like [T <: INT32|INT64](input: Output[T]). As Scala 2.x doesn't have union types at the language level (they are in Dotty), we have to encode that restriction in another way.

I found three working solutions. All have pros and cons:

  1. Create a union trait for each of the combinations and let the concrete datatype objects extend each union trait they are part of (blows up the DataType hierarchy, difficult to extend, can't restrict T to multiple concrete data types, only to the trait, i.e. you'd have to deal with Output[INT32OrFLOAT32]; no implicits necessary).
  2. Create a type class for each of the combinations (quite verbose, can restrict input to a concrete DataType, requires implicits).
  3. Encode union types via Curry-Howard isomorphism (restriction can be defined inline like real union types, can restrict to concrete types, requires implicits, might have other problems due to advanced type system tricks). See also this discussion.
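
A minimal sketch of the three encodings, assuming Scala 2.x and simplified stand-ins for the library's types. Everything here is hypothetical and only for illustration: the data types are modeled as phantom traits, Output carries just a name, and IntOrLong, IsIntOrLong, viaTrait, viaTypeClass and viaEncoding are made-up names, not part of the actual API.

```scala
object UnionSketch {
  // Hypothetical stand-ins for the library's data types (phantom traits).
  sealed trait DataType
  sealed trait IntOrLong extends DataType // union trait used by option 1
  sealed trait INT32 extends IntOrLong
  sealed trait INT64 extends IntOrLong
  sealed trait FLOAT32 extends DataType

  // Simplified Output; the real class has far more to it.
  final case class Output[+T <: DataType](name: String)

  // Option 1: bound T by the union trait. No implicits are needed,
  // but Output[IntOrLong] itself is accepted as well.
  def viaTrait[T <: IntOrLong](input: Output[T]): Output[T] = input

  // Option 2: one type class per combination, with instances only for
  // the concrete data types that belong to it.
  sealed trait IsIntOrLong[T <: DataType]
  object IsIntOrLong {
    implicit val int32: IsIntOrLong[INT32] = new IsIntOrLong[INT32] {}
    implicit val int64: IsIntOrLong[INT64] = new IsIntOrLong[INT64] {}
  }
  def viaTypeClass[T <: DataType : IsIntOrLong](input: Output[T]): Output[T] = input

  // Option 3: union encoded via double negation (Curry-Howard style).
  type Not[A] = A => Nothing
  type NotNot[A] = Not[Not[A]]
  type Or[A, B] = Not[Not[A] with Not[B]]
  def viaEncoding[T <: DataType](input: Output[T])(
      implicit ev: NotNot[T] <:< Or[INT32, INT64]): Output[T] = input

  val ints: Output[INT32] = Output[INT32]("ints")
  val floats: Output[FLOAT32] = Output[FLOAT32]("floats")

  // All three accept an Output[INT32].
  val r1: Output[INT32] = viaTrait(ints)
  val r2: Output[INT32] = viaTypeClass(ints)
  val r3: Output[INT32] = viaEncoding(ints)

  // Passing `floats` to any of the three is rejected at compile time.
  // Passing an Output[IntOrLong] compiles only with viaTrait.
}
```

In this sketch, option 1 is the only one that lets the abstract union type itself through as T, which matches the Output[INT32OrFLOAT32] concern above. Options 2 and 3 keep T concrete at the cost of an implicit parameter.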

I made T in Output covariant because I thought it would be convenient for ops that don't care about the data type to take just Output[DataType], so you can pass in e.g. an Output[FLOAT32]. But I realized this could cause problems when we want to restrict T for different inputs to be the same. For example, calling def foo[T <: DataType](a: Output[T], b: Output[T]) with a: Output[INT32], b: Output[DataType] and without an explicit type parameter infers T as DataType and compiles, even though b could have any data type. I haven't tried to make it invariant yet, because it requires more work and I'm also not sure about other implications.
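
The inference issue can be reproduced with a small sketch, again assuming hypothetical stand-in types (CoOutput, InvOutput, fooCo, fooInv and fooSame are illustrative names only). The last method shows one possible way to keep covariance while still requiring equal data types, using =:= evidence; that is an assumption added here for illustration, not something settled in this thread.

```scala
object VarianceSketch {
  // Hypothetical stand-ins for the library's data types.
  sealed trait DataType
  sealed trait INT32 extends DataType
  sealed trait FLOAT32 extends DataType

  final case class CoOutput[+T <: DataType](name: String) // covariant
  final case class InvOutput[T <: DataType](name: String) // invariant

  def fooCo[T <: DataType](a: CoOutput[T], b: CoOutput[T]): String =
    s"${a.name}+${b.name}"

  def fooInv[T <: DataType](a: InvOutput[T], b: InvOutput[T]): String =
    s"${a.name}+${b.name}"

  // Possible alternative: separate type parameters plus equality evidence.
  def fooSame[A <: DataType, B <: DataType](a: CoOutput[A], b: CoOutput[B])(
      implicit ev: A =:= B): String = s"${a.name}+${b.name}"

  val ints: CoOutput[INT32] = CoOutput[INT32]("ints")
  val anything: CoOutput[DataType] = CoOutput[FLOAT32]("floats")

  // Compiles: T is inferred as DataType, so the two inputs are not
  // actually forced to share a data type.
  val mixed: String = fooCo(ints, anything)

  // Compiles: both arguments have the same static data type.
  val same: String = fooSame(ints, CoOutput[INT32]("more"))

  // Neither of these compiles:
  //   fooInv(InvOutput[INT32]("a"), InvOutput[DataType]("b"))
  //   fooSame(ints, CoOutput[FLOAT32]("floats"))
}
```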

I'm also not quite sure if there are cases where the concrete data type can only be determined at runtime (e.g. it depends on a runtime calculation) and whether that can cause problems with the proposed design.