This is not good. Rewrite it like this and document the function:
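```python
from jax import random

# Split off a dedicated key for the shuffle; rng_key is replaced by a fresh key.
shuffle_key, rng_key = random.split(rng_key)
# A random permutation of 0..n-1, equivalent to choice(..., replace=False).
shuffle_idxs = random.permutation(shuffle_key, n)
if shuffle:
    # Reorder every field of the container with the same indices.
    data = ctor(*[el[shuffle_idxs] for el in data])
# The first n_train rows form the training split, the rest the validation split.
y_train = ctor(*[el[:n_train] for el in data])
y_val = ctor(*[el[n_train:] for el in data])
train_rng_key, val_rng_key = random.split(rng_key)
```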
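The fragment above can be wrapped in a documented function along the following lines. This is a sketch under stated assumptions, not the original code: the names `train_val_split` and `Dataset` are hypothetical, `data` is assumed to be a `NamedTuple` of arrays sharing the same leading dimension, `ctor` is taken to be `type(data)`, and `n` is read from the first field. It uses `random.permutation` in place of `random.choice(..., replace=False)`; both yield a random ordering of `0..n-1`.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import random


class Dataset(NamedTuple):
    """Hypothetical container; any NamedTuple of equal-length arrays works."""
    x: jnp.ndarray
    y: jnp.ndarray


def train_val_split(rng_key, data, n_train, shuffle=True):
    """Split a NamedTuple of arrays into training and validation sets.

    Args:
        rng_key: JAX PRNG key.
        data: NamedTuple whose fields are arrays with the same leading
            dimension n (assumption: rows are aligned across fields).
        n_train: number of leading rows assigned to the training set; the
            remaining n - n_train rows form the validation set.
        shuffle: if True, permute the rows jointly across all fields
            before splitting.

    Returns:
        (train, val, train_rng_key, val_rng_key): two containers of the same
        type as ``data`` and two fresh keys derived from ``rng_key``.
    """
    ctor = type(data)       # assumption: ctor is the container's own type
    n = data[0].shape[0]    # assumption: every field has n rows
    # Split unconditionally so the returned keys do not depend on `shuffle`.
    shuffle_key, rng_key = random.split(rng_key)
    if shuffle:
        shuffle_idxs = random.permutation(shuffle_key, n)
        data = ctor(*[el[shuffle_idxs] for el in data])
    train = ctor(*[el[:n_train] for el in data])
    val = ctor(*[el[n_train:] for el in data])
    train_rng_key, val_rng_key = random.split(rng_key)
    return train, val, train_rng_key, val_rng_key


# Usage: row i of x is [2i, 2i + 1] and y[i] is i, so alignment is easy to check.
key = random.PRNGKey(0)
ds = Dataset(x=jnp.arange(10.0).reshape(5, 2), y=jnp.arange(5))
train, val, train_key, val_key = train_val_split(key, ds, n_train=3)
```

Because the same index array is applied to every field, rows stay aligned across `x` and `y` after shuffling. Splitting `shuffle_key` even when `shuffle=False` keeps the derived training and validation keys identical regardless of the flag.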