eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Add typed Tensor #16

Closed. sbrunk closed this issue 7 years ago

sbrunk commented 7 years ago

As discussed in #14, it makes sense to start with a typed version of Tensor, perhaps similar to what's being proposed for the Java API.

sbrunk commented 7 years ago

Let's say we have a generic Tensor like this:

trait Tensor[T <: DataType]

Then, given the current design of DataType with ScalaType as a type member, we can keep the Scala type generic in Tensor:

def setElementAtFlattenedIndex(index: Int, value: T#ScalaType): Unit
def getElementAtFlattenedIndex(index: Int): T#ScalaType
def fill(value: T#ScalaType): Unit
// More flexible param types with casting should still be possible using SupportedType
...

A Tensor[FLOAT32] has T#ScalaType automatically fixed to Float. IMHO that's a very nice property of the current DataType design.
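To make the idea concrete, here is a minimal sketch, assuming Scala 2 (general type projections such as `T#ScalaType` are not accepted by Scala 3). The `DataType`, `FLOAT32`, and `INT32` definitions are simplified stand-ins rather than the library's actual ones, and `FLOAT32Tensor` and `TypedTensorDemo` are hypothetical names used only for illustration.

```scala
// Simplified, hypothetical stand-ins for the library's data types (Scala 2).
sealed trait DataType { type ScalaType }
sealed trait FLOAT32 extends DataType { type ScalaType = Float }
sealed trait INT32 extends DataType { type ScalaType = Int }

// The element type is not a separate type parameter; it is projected from T.
trait Tensor[T <: DataType] {
  def setElementAtFlattenedIndex(index: Int, value: T#ScalaType): Unit
  def getElementAtFlattenedIndex(index: Int): T#ScalaType
  def fill(value: T#ScalaType): Unit
}

// Hypothetical array-backed implementation. FLOAT32#ScalaType is an alias
// for Float, so the methods can be written directly in terms of Float.
final class FLOAT32Tensor(size: Int) extends Tensor[FLOAT32] {
  private val data = new Array[Float](size)
  def setElementAtFlattenedIndex(index: Int, value: Float): Unit = { data(index) = value }
  def getElementAtFlattenedIndex(index: Int): Float = data(index)
  def fill(value: Float): Unit = java.util.Arrays.fill(data, value)
}

object TypedTensorDemo {
  def run(): Float = {
    val t: Tensor[FLOAT32] = new FLOAT32Tensor(4)
    t.fill(1.5f)                          // only a Float is accepted here
    t.setElementAtFlattenedIndex(2, 3.0f)
    // t.fill("a")                        // would not compile: String is not Float
    t.getElementAtFlattenedIndex(2)       // statically typed as Float
  }
}
```

Even when the value is held as a plain `Tensor[FLOAT32]`, the compiler resolves `T#ScalaType` to `Float`, so no casts are needed at the call site.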

The question is what happens when we require different behavior depending on the data type (especially for String tensors). I see two directions here:

What do you think?

sbrunk commented 7 years ago

@eaplatanios I did some experimentation using a generic Tensor[T <: DataType] trait and implementation classes like FLOAT32Tensor extends Tensor[FLOAT32].

Using a TensorFactory[T <: DataType] typeclass, it's also possible to make tensor creation methods generic, so we can have something like Tensor.fill[T <: DataType : TensorFactory](value: T#ScalaType): Tensor[T] etc.
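A rough sketch of how such a typeclass could look, again assuming Scala 2. Everything here is illustrative rather than the library's API: the data types are simplified stand-ins, the implementation classes are array-backed toys, and the extra `size` parameter on `fill` is an assumption added so the example has something to allocate.

```scala
// Simplified, hypothetical stand-ins for the library's data types (Scala 2).
sealed trait DataType { type ScalaType }
sealed trait FLOAT32 extends DataType { type ScalaType = Float }
sealed trait INT32 extends DataType { type ScalaType = Int }

trait Tensor[T <: DataType] {
  def size: Int
  def getElementAtFlattenedIndex(index: Int): T#ScalaType
}

// Hypothetical implementation classes, one per data type.
final class FLOAT32Tensor(data: Array[Float]) extends Tensor[FLOAT32] {
  def size: Int = data.length
  def getElementAtFlattenedIndex(index: Int): Float = data(index)
}

final class INT32Tensor(data: Array[Int]) extends Tensor[INT32] {
  def size: Int = data.length
  def getElementAtFlattenedIndex(index: Int): Int = data(index)
}

// Typeclass: knows how to build a Tensor[T] for one specific data type T.
trait TensorFactory[T <: DataType] {
  def fill(size: Int, value: T#ScalaType): Tensor[T]
}

object TensorFactory {
  // Instances live in the companion so they are found without imports.
  implicit val float32: TensorFactory[FLOAT32] = new TensorFactory[FLOAT32] {
    def fill(size: Int, value: Float): Tensor[FLOAT32] =
      new FLOAT32Tensor(Array.fill(size)(value))
  }
  implicit val int32: TensorFactory[INT32] = new TensorFactory[INT32] {
    def fill(size: Int, value: Int): Tensor[INT32] =
      new INT32Tensor(Array.fill(size)(value))
  }
}

object Tensor {
  // Generic creation method: the context bound selects the right factory.
  def fill[T <: DataType : TensorFactory](size: Int, value: T#ScalaType): Tensor[T] =
    implicitly[TensorFactory[T]].fill(size, value)
}

object TensorFactoryDemo {
  def run(): (Float, Int) = {
    val floats = Tensor.fill[FLOAT32](3, 0.5f)
    val ints = Tensor.fill[INT32](2, 7)
    (floats.getElementAtFlattenedIndex(0), ints.getElementAtFlattenedIndex(1))
  }
}
```

In this sketch the data type has to be given explicitly at the call site (`Tensor.fill[FLOAT32](...)`), because the compiler cannot infer `T` from a value of type `T#ScalaType` alone.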

So in general I think this approach could work.

One big question is what happens when we know the data type of a tensor only at runtime, i.e. when we get it from the native library. We can still instantiate the right subclass, but what type does it have and what operations does it support?

Should it be similar to the current design using path-dependent types? Or perhaps the user has to cast it to the expected tensor type first (possibly failing at runtime) before doing any data-type-dependent operations.
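The casting option could be sketched as follows, assuming Scala 2.12 or later. This is only one way to do it and not something the library provides: `DataTypeTag`, `Tensor.Untyped`, and `Tensor.cast` are hypothetical names, and the data types are simplified stand-ins. Each tensor carries a runtime tag for its data type, and the cast checks that tag before narrowing the static type.

```scala
// Simplified, hypothetical stand-ins for the library's data types (Scala 2).
sealed trait DataType { type ScalaType }
sealed trait FLOAT32 extends DataType { type ScalaType = Float }
sealed trait INT32 extends DataType { type ScalaType = Int }

// Hypothetical runtime witness for a data type, so it can be checked dynamically.
sealed abstract class DataTypeTag[T <: DataType](val name: String)
object DataTypeTag {
  implicit object Float32Tag extends DataTypeTag[FLOAT32]("FLOAT32")
  implicit object Int32Tag extends DataTypeTag[INT32]("INT32")
}

trait Tensor[T <: DataType] {
  def tag: DataTypeTag[T]
  def getElementAtFlattenedIndex(index: Int): T#ScalaType
}

final class FLOAT32Tensor(data: Array[Float]) extends Tensor[FLOAT32] {
  val tag: DataTypeTag[FLOAT32] = DataTypeTag.Float32Tag
  def getElementAtFlattenedIndex(index: Int): Float = data(index)
}

object Tensor {
  // What a call into the native library might hand back: the right subclass
  // is instantiated, but its data type is unknown to the compiler.
  type Untyped = Tensor[_ <: DataType]

  // Checked cast: compares the runtime tag with the expected one and fails
  // with a message instead of throwing when they differ.
  def cast[T <: DataType](tensor: Untyped)(
      implicit expected: DataTypeTag[T]): Either[String, Tensor[T]] =
    if (tensor.tag == expected) Right(tensor.asInstanceOf[Tensor[T]])
    else Left(s"expected ${expected.name} but found ${tensor.tag.name}")
}

object RuntimeTypedDemo {
  def run(): Either[String, Float] = {
    val fromNative: Tensor.Untyped = new FLOAT32Tensor(Array(1.0f, 2.0f))
    // Data-type-dependent operations are available only after the cast.
    Tensor.cast[FLOAT32](fromNative).map(_.getElementAtFlattenedIndex(1))
  }
}
```

The trade-off is visible in the types: every tensor obtained at runtime starts out as `Tensor.Untyped`, and each use site has to handle the `Left` case where the expected data type does not match.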

eaplatanios commented 7 years ago

@sbrunk I believe that with the new eager execution API for tensors, which gives us numpy-like functionality, it becomes harder to keep track of types on the Scala side, since most tensors will have types known only at runtime (they are obtained by eagerly executing TensorFlow ops). Do you think it is still worth pursuing typed tensors and outputs? I'm currently focusing more on adding everything needed to build a nice learning API abstraction. I've been working on support for creating functions in TF graphs, as well as for calling Scala functions from within TensorFlow ops, similar to how py_func works in the Python API. These will be available soon.

sbrunk commented 7 years ago

You're right, it probably doesn't make sense when the type is known only at runtime most of the time. So I'm closing this.