ramsey-devs / ramsey

Probabilistic deep learning using JAX
https://ramsey.rtfd.io
Apache License 2.0
13 stars, 3 forks

As batch iterators #34

Closed: dirmeier closed this issue 1 year ago

dirmeier commented 1 year ago

This is not good. Rewrite it as follows and document the function:

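import jax.numpy as jnp
from jax import random

# Assumes in scope: rng_key, data (a namedtuple of arrays), its
# constructor ctor, the row count n, n_train, and the shuffle flag.
shuffle_key, rng_key = random.split(rng_key)
if shuffle:
    # Random permutation of the row indices, applied to every field.
    shuffle_idxs = random.choice(shuffle_key, jnp.arange(n), shape=(n,), replace=False)
    data = ctor(*[el[shuffle_idxs] for el in data])

# The first n_train rows form the training split, the rest the validation split.
y_train = ctor(*[el[:n_train] for el in data])
y_val = ctor(*[el[n_train:] for el in data])
train_rng_key, val_rng_key = random.split(rng_key)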
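The issue title asks for the split to be returned as batch iterators, and the comment asks for the function to be documented. Below is a minimal sketch of what that could look like, assuming `data` is a namedtuple of arrays that share the same leading dimension. The names `Data`, `batch_iterator`, and `as_batch_iterators` and their parameters are hypothetical, not ramsey's actual API. The sketch also uses `random.permutation` in place of `random.choice(..., replace=False)`; both produce a random ordering of the row indices.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import random

# Hypothetical data container; the issue does not show how `data` is defined.
Data = namedtuple("Data", ["x", "y"])


def batch_iterator(rng_key, data, batch_size, shuffle=True):
    """Yield minibatches of `data` (hypothetical helper).

    Args:
        rng_key: a `jax.random` PRNG key, used only if `shuffle` is True.
        data: a namedtuple of arrays sharing the same leading dimension.
        batch_size: number of rows per minibatch; the last batch may be smaller.
        shuffle: whether to permute the rows once before batching.
    """
    ctor = type(data)
    n = data[0].shape[0]
    idxs = random.permutation(rng_key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        batch_idxs = idxs[start:start + batch_size]
        # Index every field with the same rows so fields stay aligned.
        yield ctor(*[el[batch_idxs] for el in data])


def as_batch_iterators(rng_key, data, n_train, batch_size, shuffle=True):
    """Split `data` into training and validation batch iterators (hypothetical).

    Args:
        rng_key: a `jax.random` PRNG key.
        data: a namedtuple of arrays sharing the same leading dimension.
        n_train: number of rows assigned to the training split.
        batch_size: number of rows per minibatch.
        shuffle: whether to permute the rows before splitting and batching.

    Returns:
        A tuple `(train_iter, val_iter)` of generators over namedtuple batches.
    """
    ctor = type(data)
    n = data[0].shape[0]
    shuffle_key, rng_key = random.split(rng_key)
    if shuffle:
        # Random permutation of the row indices, applied to every field.
        shuffle_idxs = random.permutation(shuffle_key, n)
        data = ctor(*[el[shuffle_idxs] for el in data])
    y_train = ctor(*[el[:n_train] for el in data])
    y_val = ctor(*[el[n_train:] for el in data])
    # Independent keys so the two iterators shuffle independently.
    train_rng_key, val_rng_key = random.split(rng_key)
    train_iter = batch_iterator(train_rng_key, y_train, batch_size, shuffle)
    val_iter = batch_iterator(val_rng_key, y_val, batch_size, shuffle)
    return train_iter, val_iter
```

For example, `as_batch_iterators(random.PRNGKey(0), Data(x, y), n_train=8, batch_size=3)` on 10 rows would yield training batches of 3, 3, and 2 rows and a single validation batch of 2 rows. Returning generators keeps batching lazy; shuffling the validation split with its own key is a design choice here, not something the issue specifies.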