raskr / rust-autograd

Tensors and differentiable operations (like TensorFlow) in Rust
MIT License

Custom rand::Rng #1

Closed: shanegibbs closed this issue 6 years ago

shanegibbs commented 6 years ago

Your library is exactly what I have been looking for :heart_eyes:!

I have been trying to rig it up to a wasm build, and I'm hitting a runtime error. When calling autograd::ndarray_ext::array_gen::glorot_uniform, I get a stack trace like the following:

    std::panicking::rust_panic_with_hook
    std::panicking::begin_panic
    rand::jitter::get_nstime
    rand::jitter::JitterRng::new_with_timer
    rand::jitter::JitterRng::new
    rand::StdRng::new
    rand::thread_rng::THREAD_RNG_KEY::__init
    <std::thread::local::LocalKey<T>>::init
    <std::thread::local::LocalKey<T>>::try_with
    <std::thread::local::LocalKey<T>>::with
    rand::thread_rng
    rand::weak_rng
    autograd::ndarray_ext::array_gen::glorot_uniform

So we are panicking in the rand::jitter::get_nstime function. This makes sense: when I look at jitter.rs:703, the wasm implementation is unreachable!().

Is it possible to pass in a custom &mut rand::Rng to these functions? If that was the case, I could use something like pcg_rand.

Thanks :+1:
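
The request, passing a caller-supplied generator instead of relying on thread_rng, can be sketched without any external crates. In this minimal sketch, UniformSource is a hypothetical stand-in for rand::Rng, Pcg32 is a hand-written PCG-XSH-RR 64/32 generator in the spirit of pcg_rand (not that crate's API), and glorot_uniform_with is a hypothetical function name; the scale s = sqrt(6 / shape[0]) follows the formula in raskr's draft in this thread. Because the generator is seeded explicitly, constructing it never touches OS entropy or a timer, which is the path that panics on wasm.

```rust
/// Hypothetical stand-in for `rand::Rng`: any source of uniform u32 values.
pub trait UniformSource {
    fn next_u32(&mut self) -> u32;

    /// Uniform f64 in [0, 1), built from one 32-bit draw.
    fn next_f64(&mut self) -> f64 {
        self.next_u32() as f64 / 4_294_967_296.0
    }
}

/// PCG32 (PCG-XSH-RR, 64-bit state, 32-bit output), written by hand here.
/// Seeding is pure arithmetic: no OS entropy, no timer.
pub struct Pcg32 {
    state: u64,
    inc: u64,
}

impl Pcg32 {
    const MULT: u64 = 6_364_136_223_846_793_005;

    /// `seed` sets the starting state, `stream` selects the sequence.
    pub fn new(seed: u64, stream: u64) -> Self {
        let mut rng = Pcg32 { state: 0, inc: (stream << 1) | 1 };
        rng.next_u32();
        rng.state = rng.state.wrapping_add(seed);
        rng.next_u32();
        rng
    }
}

impl UniformSource for Pcg32 {
    fn next_u32(&mut self) -> u32 {
        let old = self.state;
        // Advance the underlying LCG.
        self.state = old.wrapping_mul(Self::MULT).wrapping_add(self.inc);
        // Output permutation: xorshift high bits, then a random rotation.
        let xorshifted = (((old >> 18) ^ old) >> 27) as u32;
        let rot = (old >> 59) as u32;
        xorshifted.rotate_right(rot)
    }
}

/// Hypothetical Glorot-style initializer that takes the generator from the caller.
/// Returns a flat row-major buffer of samples from U[-s, s).
pub fn glorot_uniform_with<R: UniformSource>(rng: &mut R, shape: &[usize]) -> Vec<f32> {
    let s = (6.0 / shape[0] as f64).sqrt();
    let len: usize = shape.iter().product();
    (0..len)
        .map(|_| (-s + 2.0 * s * rng.next_f64()) as f32)
        .collect()
}
```

Usage would be along the lines of let mut rng = Pcg32::new(42, 54); followed by glorot_uniform_with(&mut rng, &[2, 3]), which yields six values; wrapping the buffer in an ndarray is left out to keep the sketch dependency-free.
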

raskr commented 6 years ago

Thanks for your interest in this project! This is a curious issue, though I'm not familiar with WebAssembly at all. Anyway, as you said, letting the array generator functions use a pre-instantiated Rng object would solve this problem. This abstraction is definitely useful, but I also want to provide a default Rng, so the compromise would be something like:

    let arr1 = ndarray_ext::ArrRng::default().glorot_uniform(&[2, 3]);
    let arr2 = ndarray_ext::ArrRng::new(Pcg).glorot_uniform(&[2, 3]);

What do you think? In that case, the library implementation would look like this:

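    // rand 0.4-era API, as used in this thread.
    use rand::{self, Rng, XorShiftRng};
    use rand::distributions::{IndependentSample, Range};

    // Assumes `NdArray` is the crate's dynamic-dimension f32 array alias, in scope here.
    pub struct ArrRng<R: Rng> {
        rng: R
    }

    impl Default for ArrRng<XorShiftRng>
    {
        fn default() -> Self
        {
            // weak_rng() seeds itself from thread_rng() (see the trace above),
            // so only this default path still depends on OS entropy.
            ArrRng {
                rng: rand::weak_rng()
            }
        }
    }

    impl<R: Rng> ArrRng<R> {
        // Wrap a caller-supplied generator (e.g. a PCG) instead of the default.
        pub fn new(rng: R) -> Self
        {
            ArrRng {
                rng
            }
        }
    }

    impl<R: Rng> ArrRng<R>
    {
        pub fn glorot_uniform(&mut self, shape: &[usize]) -> NdArray
        {
            // Samples from U[-s, s) with s = sqrt(6 / shape[0]).
            let s = (6. / shape[0] as f64).sqrt();
            let dist = Range::new(-s, s);
            // ind_sample (from IndependentSample) draws each element from self.rng.
            NdArray::from_shape_fn(shape, |_| dist.ind_sample(&mut self.rng) as f32)
        }
        ...
    }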
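
The proposed ArrRng<R> design can be exercised without pulling in rand. This minimal sketch mirrors it with hypothetical stand-ins: UniformSource replaces rand::Rng, XorShift64 (Marsaglia's 64-bit xorshift with a fixed seed) plays the part of the default XorShiftRng, and glorot_uniform returns a flat Vec<f32> instead of an NdArray. Only the shape of the API, Default for a built-in generator and new for a caller-supplied one, is meant to match the thread.

```rust
/// Hypothetical stand-in for `rand::Rng` (not the real trait).
pub trait UniformSource {
    fn next_u64(&mut self) -> u64;

    /// Uniform f64 in [0, 1) from the top 53 bits of one draw.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Marsaglia xorshift64 (shifts 13, 7, 17), standing in for the default XorShiftRng.
/// The state must never be zero, or the generator gets stuck at zero.
pub struct XorShift64(pub u64);

impl UniformSource for XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Mirrors the proposed `ArrRng<R>`: owns a generator, exposes array generators.
pub struct ArrRng<R: UniformSource> {
    rng: R,
}

impl Default for ArrRng<XorShift64> {
    fn default() -> Self {
        // Assumption for the sketch: a fixed nonzero seed, so no entropy lookup.
        ArrRng { rng: XorShift64(0x9E37_79B9_7F4A_7C15) }
    }
}

impl<R: UniformSource> ArrRng<R> {
    /// Wrap any caller-supplied generator.
    pub fn new(rng: R) -> Self {
        ArrRng { rng }
    }

    /// Flat row-major buffer of samples from U[-s, s), s = sqrt(6 / shape[0]).
    pub fn glorot_uniform(&mut self, shape: &[usize]) -> Vec<f32> {
        let s = (6.0 / shape[0] as f64).sqrt();
        let len: usize = shape.iter().product();
        (0..len)
            .map(|_| (-s + 2.0 * s * self.rng.next_f64()) as f32)
            .collect()
    }
}
```

With this shape, ArrRng::default().glorot_uniform(&[2, 3]) uses the built-in generator and ArrRng::new(XorShift64(12345)).glorot_uniform(&[2, 3]) uses a caller-supplied one; any other type implementing the trait, such as a PCG generator, slots in the same way. Seeding the default with a constant, unlike weak_rng(), sidesteps the entropy lookup, at the cost of every default-constructed ArrRng producing the same sequence.
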

shanegibbs commented 6 years ago

Yes, that would solve my problem. I like this design too 🥇

It looks like you basically have it sorted out, but let me know if you would like a hand with anything.

raskr commented 6 years ago

OK, the next version (it may take several days) will contain code close to this. Thank you, I will. Immediate PRs are also welcome, since this is an experimental project 👍