sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Fix return type of multinomial and add optional generator #55

Closed: marcelluethi closed this pull request 1 year ago

marcelluethi commented 1 year ago

This PR changes the return type of the multinomial distribution to int64. This is correct because the underlying PyTorch function always returns a `long` (int64) tensor of sampled indices, as described in #54. The PR also adds an optional argument that lets a generator be passed to the function.
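To illustrate why an integer return type fits, here is a minimal plain Scala 3 sketch. It is not storch's API: `multinomialSketch` and its parameters are hypothetical, and `scala.util.Random` stands in for a torch generator. The function draws category *indices*, so the result is integral regardless of the element type of the weights. It also accepts an optional generator so that seeded calls are reproducible.

```scala
import scala.util.Random

// Hypothetical helper, NOT storch's multinomial. scala.util.Random stands in
// for a torch generator. Returns sampled category indices as Long values.
def multinomialSketch(
    weights: Seq[Double],
    numSamples: Int,
    replacement: Boolean = false,
    generator: Option[Random] = None
): Seq[Long] =
  require(weights.nonEmpty && weights.forall(_ >= 0.0), "weights must be non-negative")
  require(weights.exists(_ > 0.0), "at least one weight must be positive")
  require(replacement || numSamples <= weights.count(_ > 0.0),
    "without replacement, numSamples must not exceed the number of non-zero weights")
  // Fall back to an unseeded generator when none is supplied.
  val rng = generator.getOrElse(new Random())
  // Mutable copy so drawn categories can be zeroed out when sampling without replacement.
  val w = weights.toArray
  (0 until numSamples).map { _ =>
    // Inverse-CDF draw: pick the first index whose cumulative weight exceeds u.
    val u = rng.nextDouble() * w.sum
    var acc = 0.0
    var chosen = -1
    var i = 0
    while chosen < 0 && i < w.length do
      acc += w(i)
      if w(i) > 0.0 && u < acc then chosen = i
      i += 1
    // Guard against floating-point rounding leaving u at the very end of the range.
    if chosen < 0 then chosen = w.lastIndexWhere(_ > 0.0)
    // Without replacement, a drawn category cannot be drawn again.
    if !replacement then w(chosen) = 0.0
    chosen.toLong
  }
```

Passing the same seeded generator, e.g. `generator = Some(new Random(7))`, to two calls yields the same indices. Omitting `generator` falls back to an unseeded `Random`.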

davoclavo commented 1 year ago

Thanks a lot for the fix, LGTM!

sbrunk commented 1 year ago

Thanks @marcelluethi!