eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0
936 stars, 96 forks

Make matrixBandPart work with other integer types than INT64 #49

Closed: carlo-veezoo closed this 6 years ago

carlo-veezoo commented 6 years ago

I would expect the following code to work:

import org.platanios.tensorflow.api._

object MatrixBandPart {
  def main(args: Array[String]): Unit = {
    tf.matrixBandPart(Seq(Seq(0)), 0, 0)
  }
}

However, it does not, because matrixBandPart expects INT64 arguments. These changes make it work by casting other integer types to INT64, and they introduce checks so that non-integer types (e.g. float) are not cast.
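To illustrate the idea, here is a minimal, self-contained sketch of the "widen integers, reject floats" check. It does not use the tensorflow_scala API and is not the code from this pull request: the `Bound` hierarchy and the `BandBounds.toInt64` helper are hypothetical names made up for this example.

```scala
// Hypothetical stand-ins for typed band bounds; NOT tensorflow_scala classes.
sealed trait Bound
final case class Int32Bound(value: Int) extends Bound
final case class Int64Bound(value: Long) extends Bound
final case class Float32Bound(value: Float) extends Bound

object BandBounds {
  // Widen any integer bound to a Long (the INT64 the op expects).
  // Non-integer bounds are rejected instead of being silently cast.
  def toInt64(bound: Bound): Either[String, Long] = bound match {
    case Int32Bound(v)   => Right(v.toLong) // lossless widening
    case Int64Bound(v)   => Right(v)        // already the expected type
    case Float32Bound(v) => Left(s"Expected an integer bound, but got the float value $v.")
  }
}
```

Under these assumptions, an `Int32Bound(0)` is accepted and widened, while a `Float32Bound` produces an error message rather than a truncated value.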

eaplatanios commented 6 years ago

Good catch! :)

eaplatanios commented 6 years ago

@csaladin94 FYI if you want to show up as a contributor and make sure your commits are linked to your account, you should associate carlo@veezoo.com with your GitHub account. :)

carlo-veezoo commented 6 years ago

@eaplatanios Thanks for the information, I just did that!