sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0
113 stars · 7 forks

Multinomial should return Int64 instead of Float #54

Closed · marcelluethi closed this issue 1 year ago

marcelluethi commented 1 year ago

I ran into a problem while using torch.multinomial. When I convert the resulting tensor to a Seq and map over its elements, I get the following error:

Exception in thread "main" java.lang.ClassCastException: class java.lang.Long cannot be cast to class java.lang.Float (java.lang.Long and java.lang.Float are in module java.base of loader 'bootstrap')

The problem is that PyTorch returns a long tensor, but the Storch return type is Tensor[D] (where D <: FloatNN):

  def multinomial[D <: FloatNN](
      input: Tensor[D],
      numSamples: Long,
      replacement: Boolean = false,
  ): Tensor[D] =
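To make the failure mode concrete, here is a minimal sketch of the kind of usage that triggers the ClassCastException (the exact tensor-construction and `toSeq` calls are assumptions based on the description above):

```scala
import torch.*

// Probabilities for a three-outcome categorical distribution.
val probs = Tensor(Seq(0.2f, 0.3f, 0.5f))

// Natively this returns Int64 indices, but the declared return type
// here is Tensor[Float32], so the boxed runtime values are Longs.
val samples = torch.multinomial(probs, numSamples = 4, replacement = true)

// Mapping over the elements forces a cast of java.lang.Long to
// java.lang.Float and throws ClassCastException at runtime.
val shifted = samples.toSeq.map(v => v + 1f)
```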

The following change fixes the problem for me:

  def multinomial[D <: FloatNN](
      input: Tensor[D],
      numSamples: Long,
      replacement: Boolean = false
  ): Tensor[Int64] =
    Tensor(torchNative.multinomial(input.native, numSamples, replacement), dtype=int64)

Since there are multiple TODOs in the code, some of which I don't understand well enough, I did not open a PR. Let me know if you would like me to open a PR with this change.

ps. It would also be nice if multinomial would take a generator as an argument.

davoclavo commented 1 year ago

Hi @marcelluethi - You are correct, it indeed should return a Tensor[Int64] as it is returning indexes. Thanks a lot for reporting the issue, and your proposed fix looks good, definitely PR material.

Regarding the multiple TODOs I see in that operation:

  1. // TODO Demote Float to Int was, I guess, written at the time to question whether the function always returns Int64 regardless of the input type, or whether the input tensor type (e.g. Float64, Float32) might cause the multinomial operation to return distinct tensor types. From what I am seeing in PyTorch, it always returns a tensor of Int64, so feel free to delete that comment and implement your proposed change.

  2. // TODO Handle Optional Generators properly It would definitely be good to handle generators as part of the inputs for this operation (and incrementally for many other operations). @sbrunk added support for them recently in #50, so I think you would be able to implement your extra proposal as well. I think you would just have to add an input argument like generator: Option[Generator] | Generator = None, and then you could use this extension method to convert it into the type required by the native Java libtorch wrapper: "OptionalGenerator"
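Putting the two suggestions together, the signature might look roughly like this. This is only a sketch: the `toOptional` extension name is an assumption based on the pattern introduced in #50, and the native overload accepting a generator is assumed to exist.

```scala
// Sketch: multinomial returning Int64 indices and accepting an
// optional random generator, mirroring the fix proposed above.
def multinomial[D <: FloatNN](
    input: Tensor[D],
    numSamples: Long,
    replacement: Boolean = false,
    generator: Option[Generator] | Generator = None // assumed pattern from #50
): Tensor[Int64] =
  Tensor(
    torchNative.multinomial(
      input.native,
      numSamples,
      replacement,
      generator.toOptional // assumed extension converting to OptionalGenerator
    ),
    dtype = int64
  )
```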

sbrunk commented 1 year ago

Yes the torch docs also say it returns a LongTensor: https://pytorch.org/docs/stable/generated/torch.multinomial.html#torch.multinomial

Fixed via #55

Greetings from Madrid, where Scala Days has just ended. I've given my Storch talk and had lots of interesting conversations about it. Need to recover now.