Closed marcelluethi closed 1 year ago
Hi @marcelluethi - You are correct, it should indeed return a Tensor[Int64], as it is returning indexes.
Thanks a lot for reporting the issue, and your proposed fix looks good, definitely PR material.
Regarding the multiple TODOs I see in that operation:
// TODO Demote Float to Int
I guess this was written at the time to question whether the function always returns Int64 regardless of the input type, or whether the input tensor type (e.g., Float64, Float32) may cause the multinomial operation to return different tensor types.
From what I am seeing in PyTorch, it always returns a tensor of Int64, so feel free to delete that comment and implement your proposed change.
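To illustrate why the result type does not depend on the input type, here is a minimal sketch in plain Scala 3. It is not Storch or LibTorch code: sampleIndices is a hypothetical helper that draws indices with replacement from unnormalized weights using scala.util.Random. Whatever numeric type the weights have, the result is a sequence of Long indices, which is the analogue of returning a Tensor[Int64].

```scala
import scala.util.Random

// Hypothetical helper, not part of Storch: samples indices with replacement
// from non-negative, unnormalized weights of any numeric type W.
def sampleIndices[W](weights: Seq[W], numSamples: Int, rng: Random)(using num: Numeric[W]): Seq[Long] =
  val w = weights.map(num.toDouble)
  require(w.nonEmpty && w.forall(_ >= 0.0) && w.sum > 0.0,
    "weights must be non-negative with a positive sum")
  // Running totals: cumulative(i) = w(0) + ... + w(i)
  val cumulative = w.scanLeft(0.0)(_ + _).tail
  val total = cumulative.last
  Seq.fill(numSamples) {
    val u = rng.nextDouble() * total
    // First index whose running total exceeds u; zero-weight entries are never chosen.
    val idx = cumulative.indexWhere(_ > u)
    // Fallback guards against floating-point rounding at the upper boundary.
    (if idx >= 0 then idx else w.lastIndexWhere(_ > 0.0)).toLong
  }
```

Calling it with Seq[Float] or Seq[Double] weights gives a Seq[Long] in both cases, so callers can map over the result as indices without any conversion.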
// TODO Handle Optional Generators properly
It would definitely be good to handle generators as part of the inputs for this operation (and incrementally for many other operations). @sbrunk added support for them recently in #50, so I think you would be able to implement your extra proposal as well.
I think you would just have to add an input argument like generator: Option[Generator] | Generator = None and then you could use this extension method to convert it into the type required by the native Java LibTorch wrapper: "OptionalGenerator"
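As a rough sketch of how such a union-typed argument can be normalized, the following plain Scala 3 example uses stand-ins only. The Generator class, the asOption extension, and drawIndex are hypothetical and are not the Storch API. The real conversion would target the native optional generator type instead of Option.

```scala
import scala.util.Random

// Hypothetical stand-in for a generator; not Storch's Generator class.
final class Generator(val seed: Long)

// Normalizes "a generator, an optional generator, or nothing" into an Option.
extension (g: Option[Generator] | Generator)
  def asOption: Option[Generator] = g match
    case o: Option[Generator] => o
    case gen: Generator       => Some(gen)

// Hypothetical operation accepting the union type, defaulting to no generator.
def drawIndex(numClasses: Int, generator: Option[Generator] | Generator = None): Long =
  val rng = generator.asOption match
    case Some(gen) => new Random(gen.seed) // seeded, reproducible
    case None      => new Random()         // unseeded default
  rng.nextInt(numClasses).toLong
```

With this shape, callers can write drawIndex(10), drawIndex(10, gen) or drawIndex(10, Some(gen)) without wrapping the argument themselves.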
Yes, the torch docs also say it returns a LongTensor: https://pytorch.org/docs/stable/generated/torch.multinomial.html#torch.multinomial
Fixed via #55
Greetings from Madrid, where Scala Days has just ended. I've given my Storch talk and had lots of interesting conversations about it. Need to recover now.
I ran into a problem while using torch.multinomial. When taking the resulting tensor, converting it to a Seq and mapping over the elements, I get the following error:

The problem is that pytorch returns a long, but the return type of storch is Tensor[D] (where D <: FloatNN):

The following change fixes the problem for me:
As there are multiple TODOs in the code, some of which I don't understand clearly enough, I did not start a PR. Let me know if you would like me to make a PR with this change.
PS: It would also be nice if multinomial took a generator as an argument.