sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Add Generator #22

Closed: sbrunk closed this issue 1 year ago

sbrunk commented 1 year ago

Creates and returns a generator object that manages the state of the algorithm which produces pseudo random numbers. Used as a keyword argument in many In-place random sampling functions.

Via https://pytorch.org/docs/stable/generated/torch.Generator.html
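
To illustrate the pattern without depending on libtorch, here is a minimal pure-Scala sketch. A generator object owns the PRNG state and is passed explicitly to a sampling function, so equal seeds give equal samples and separate generators do not disturb each other. `SimpleGenerator` and `uniformInPlace` are hypothetical names, not Storch or PyTorch API. `java.util.Random` stands in for libtorch's CPU generator (mt19937), so the numbers will not match PyTorch.

```scala
import java.util.Random

// Hypothetical stand-in for a generator object: it owns its PRNG state.
// java.util.Random replaces libtorch's mt19937, so values differ from PyTorch.
final class SimpleGenerator(seed: Long):
  private val rng = new Random(seed)
  def nextFloat(): Float = rng.nextFloat() // uniform in [0, 1)

// Hypothetical in-place sampling function in the style of
// Tensor.uniform_(generator=g): fills `buf` using only the given generator.
def uniformInPlace(buf: Array[Float], generator: SimpleGenerator): Unit =
  for i <- buf.indices do buf(i) = generator.nextFloat()

@main def generatorDemo(): Unit =
  val a = Array.ofDim[Float](3)
  val b = Array.ofDim[Float](3)
  val g1 = SimpleGenerator(42L)
  val g2 = SimpleGenerator(42L)
  val other = SimpleGenerator(7L)

  uniformInPlace(a, g1)
  other.nextFloat() // consuming another generator does not touch g2's state
  uniformInPlace(b, g2)

  println(a.sameElements(b)) // true: same seed, same samples
```

Passing the generator explicitly, rather than relying on a single global one, is what keeps unrelated random draws elsewhere in a program from shifting the sequence.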

We have to wait for https://github.com/bytedeco/javacpp-presets/issues/1259 to be resolved; then the implementation could look something like this:

import org.bytedeco.pytorch

// Sketch only: assumes it lives in Storch's `torch` package (Device, Tensor, UInt8)
// and that the presets expose libtorch's Generator as pytorch.Generator.

/** Creates and returns a generator object that manages the state of the algorithm which produces pseudo random numbers.
  */
class Generator(val device: Device = Device.CPU) {
  // Explicit type: a bare `???` infers Nothing, and the member calls below would not compile.
  private val native: pytorch.Generator = ??? // waiting for make_generator_cpu/cuda

  /** Returns the generator state as a `Tensor[UInt8]`. */
  def getState: Tensor[UInt8] = Tensor(native.get_state())

  /** Returns the initial seed for generating random numbers. */
  def initialSeed: Long = native.current_seed()

  /** Sets the seed for generating random numbers. Returns this generator.
   *
   * It is recommended to set a large seed, i.e. a number that has a good balance of 0 and 1 bits. Avoid having many 0 bits in the seed.
   */
  def manualSeed(seed: Long): Generator = {
    native.set_current_seed(seed)
    this
  }

  /** Seeds the generator with a non-deterministic random number and returns that seed. */
  def seed(): Long = native.seed()
}
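
Until the native bindings land, the intended method semantics can be exercised against a pure-JVM mock. The hypothetical `MockGenerator` below mirrors PyTorch's documented `torch.Generator` behavior: `manual_seed` sets the seed and returns the generator, `initial_seed` returns the seed in use, `seed()` applies a non-deterministic seed and returns it, and `get_state`/`set_state` snapshot and restore the PRNG. The SplitMix64 backend is an assumption, chosen only because its entire state is a single `Long`; libtorch does not use it, so values will not match PyTorch.

```scala
import java.security.SecureRandom

// Hypothetical mock of the proposed Generator. SplitMix64 replaces libtorch's
// mt19937, so only the method semantics match torch.Generator, not the values.
final class MockGenerator(initialSeedValue: Long = 0L):
  private var initial: Long = initialSeedValue
  private var state: Long = initialSeedValue

  /** Like torch.Generator.manual_seed: sets the seed and returns this generator. */
  def manualSeed(seed: Long): MockGenerator =
    initial = seed
    state = seed
    this

  /** Like torch.Generator.initial_seed: the seed most recently applied. */
  def initialSeed: Long = initial

  /** Like torch.Generator.seed: applies a non-deterministic seed and returns it. */
  def seed(): Long =
    val s = new SecureRandom().nextLong()
    manualSeed(s)
    s

  /** Snapshot and restore of the full PRNG state (a single Long for SplitMix64). */
  def getState: Long = state
  def setState(s: Long): Unit = state = s

  /** One SplitMix64 step: advances the state and returns a mixed 64-bit value. */
  def nextLong(): Long =
    state += 0x9E3779B97F4A7C15L
    var z = state
    z = (z ^ (z >>> 30)) * 0xBF58476D1CE4E5B9L
    z = (z ^ (z >>> 27)) * 0x94D049BB133111EBL
    z ^ (z >>> 31)

@main def mockGeneratorDemo(): Unit =
  val g = MockGenerator().manualSeed(1234L)
  val snapshot = g.getState
  val first = Seq.fill(3)(g.nextLong())
  g.setState(snapshot)
  val replay = Seq.fill(3)(g.nextLong())
  println(first == replay) // true: restoring the state replays the sequence
  println(g.initialSeed)   // 1234
```

Restoring a snapshot replays exactly the same sequence, which is the property that makes runs reproducible. The side-effecting `seed()` takes parentheses, following the Scala convention for methods that mutate state.
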
davoclavo commented 1 year ago

It would be great to get this working; hopefully it can be added to JavaCPP soon 🤞

It would be really useful for reproducing results from Python implementations and tracking down discrepancies between them.