eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

ScalaType instead of DataType as type param. #122

Closed DirkToewe closed 6 years ago

DirkToewe commented 6 years ago

I almost feel bad for asking this, considering how much refactoring is involved, but the answer may have far-reaching consequences. So here it comes: I wonder if it might not be better to use the Scala type as the type parameter instead of the TensorFlow dtype. In other words: maybe Tensor[Double] would be preferable to Tensor[FLOAT64]? The dtype would then become something inferred implicitly, similar to a ClassTag. There are a lot of reasons that come to my mind as to why this would be a good idea, but I might be wrong. A rough sketch of the idea first, then the reasons:
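This is a minimal, purely illustrative sketch (the names are hypothetical, not the library's actual API): the type parameter is the Scala element type, and the dtype is recovered from an implicit DataType[T] evidence value, much like a ClassTag.

// Hypothetical evidence linking a Scala element type to its TensorFlow dtype.
trait DataType[T] { def name: String }

object DataType {
  implicit val float32: DataType[Float]  = new DataType[Float]  { val name = "FLOAT32" }
  implicit val float64: DataType[Double] = new DataType[Double] { val name = "FLOAT64" }
  implicit val int32:   DataType[Int]    = new DataType[Int]    { val name = "INT32"   }
}

// A toy tensor indexed by the Scala type; the dtype travels with the implicit evidence.
final case class Tensor[T: DataType]( data: Vector[T] ) {
  def dataType: DataType[T] = implicitly[DataType[T]]
}

val t = Tensor(Vector(1.0, 2.0, 3.0))  // inferred as Tensor[Double]
println(t.dataType.name)               // prints FLOAT64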

Intuition

DataType is an artifact from NumPy that is going to feel weird to everyone who doesn't know NumPy. For someone without a Python background, the gap between FLOAT32 and the actual content type Float is going to feel even bigger.

Type Inference Issues

When member types are involved, the Scala compiler seems to have real issues with type inference. As an example, let's look at the following function:

def map[A <: DataType,B <: DataType]( a: Tensor[A] )( f: A#ScalaType => B#ScalaType )( implicit dt: B ): Tensor[B]
  = Tensor[B,B#ScalaType]( a.entriesIterator.map(f).toSeq: _* )
      .reshape(a.shape) ensuring ( dt.getClass isAssignableFrom _.dataType.getClass )

This will not compile since Aux[B,B#ScalaType] cannot be found. If DataType[B] were the implicit tag for a type B, the Aux pattern would not be necessary in the first place. Let's say we get this to compile:

def map[A <: DataType,B <: DataType]( a: Tensor[A], f: A#ScalaType => B#ScalaType )( implicit dt: B ): Tensor[B]
  = ???

println{ map( Tensor(1,2,3), (i: Int) => i*3f )(FLOAT32) }

The compiler is unable to infer type A correctly in this example:

[error]  type mismatch;
[error]  found   : Int => Float
[error]  required: ?#ScalaType => ?#ScalaType
[error]     println{ map( Tensor(1,2,3), (i: Int) => i*3f )(FLOAT32) }
[error]                                           ^

Without member types involved, the compiler has a much simpler life:

import scala.reflect.ClassTag

def map[A, B: ClassTag]( a: Seq[A], f: A => B ): Seq[B]
  = a.map{ x => f(x) ensuring ( y => implicitly[ClassTag[B]].unapply(y).isDefined ) }

println{ map( Seq(1,2,3), (i: Int) => i*3f ) }

Verbosity

Let's say, for example, we wanted to implement a custom map() function for a tensor:

def map[A <: DataType,B <: DataType]( a: Tensor[A], f: A#ScalaType => B#ScalaType )( implicit dt: B ): Tensor[B] = ???

Even in this simple example, things are a lot more verbose than they would have to be.

def map[A,B: DataType]( a: Tensor[A], f: A => B ): Tensor[B] = ???

There may also be type inference issues for the compiler here. In my experience, type expressions in Scala explode very quickly, and keeping them as concise as possible is desirable.
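For illustration, here is a sketch of how the concise signature could look end to end, using the same kind of hypothetical DataType[T] evidence as in the sketch above (again, illustrative names only, not the actual API). Both type parameters and the implicit dtype evidence are inferred at the call site:

trait DataType[T]

object DataType {
  implicit val int32:   DataType[Int]   = new DataType[Int]   {}
  implicit val float32: DataType[Float] = new DataType[Float] {}
}

final case class Tensor[T]( data: Vector[T] )

def map[A, B: DataType]( a: Tensor[A], f: A => B ): Tensor[B]
  = Tensor(a.data.map(f))

println{ map( Tensor(Vector(1, 2, 3)), (i: Int) => i*3f ) }  // Tensor(Vector(3.0, 6.0, 9.0))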

Performance/Specialization

Scala already has specialization, even though using it may still be a little tedious. With Project Valhalla, specialization is going to become pretty much a free lunch, but most likely not for Scala's member types. While in most applications we should not care about boxing/specialization at all, it is very relevant to machine learning: a boxed float can be seven times larger than its unboxed counterpart and adds additional GC overhead.
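For reference, this is roughly what opting into specialization looks like in Scala today (standard @specialized on a throwaway, illustrative class, nothing specific to this library): the compiler generates primitive variants of the annotated class, so the floats stay unboxed.

// The compiler generates primitive variants of this class for the listed types,
// so `data` stays a primitive array and element access involves no boxing.
final class DenseBuffer[@specialized(Float, Double, Int, Long) T]( val data: Array[T] ) {
  def apply(i: Int): T = data(i)
  def length: Int = data.length
}

val buf = new DenseBuffer(Array(1f, 2f, 3f))  // uses the Float-specialized variant
println(buf(1))                               // 2.0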

What about UInt32, QInt32, ... ?

Fortunately, Scala supports value classes, which make it easy to implement new data types of 64 bits or fewer that are both efficient and intuitive, so QInt32#ScalaType would not need to be Int.

// This value class is represented as a plain Short at runtime.
class QInt16( val bits: Short ) extends AnyVal {
  def toFloat = bits.toFloat / Short.MaxValue
}
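A quick usage sketch (the scaling convention above is just an example):

val q = new QInt16(Short.MaxValue)  // carried around as a plain Short, no wrapper allocation
println(q.toFloat)                  // 1.0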

Spire already offers unsigned and complex data types.
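For example, with a Spire dependency on the classpath (illustrative only):

import spire.implicits._
import spire.math.{Complex, UInt}

val u = UInt(Int.MaxValue) + UInt(1)           // 2147483648: representable as an unsigned 32-bit value
val c = Complex(1.0, 2.0) * Complex(1.0, 2.0)  // (1 + 2i)^2 = -3 + 4i
println(u)
println(c)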

eaplatanios commented 6 years ago

@DirkToewe Sorry for delaying my response to this for so long, but it deserves a long answer and I didn't want to post something quick while traveling. I have thought about this issue before, and there are a few reasons I originally decided against it, even though I debated a lot with myself. I have been working on something that may resolve this, so I'll respond with details in a few days, depending on how that goes. :)

eaplatanios commented 6 years ago

@DirkToewe In the meantime, you can track some temporary progress I'm making while trying out stuff in the scala_types branch.

eaplatanios commented 6 years ago

I'll close this as I've made significant progress in this direction in #131. Please comment there if you have feedback, and thanks for suggesting this: I debated for a long time whether it was worth the effort, and your post made me finally go for it. :)

eaplatanios commented 5 years ago

@DirkToewe #131 has finally been merged. It addresses the suggestion you made here and, more broadly, revamps the typing of tensors in the TF Scala API. :)