init-22 opened this issue 4 days ago
@priyakasimbeg
I can resolve this by using jax.random.key_data(key). I've added the following lines to submission_runner.py: data_rng = jax.random.key_data(data_rng) at line 213 and eval_rng = jax.random.key_data(eval_rng) at line 339.
I did not encounter the error with the other keys, so maybe they don't need to be changed. Can you please tell me if this seems okay?
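For reference, a minimal sketch of what the key_data workaround does, assuming JAX 0.4.35 with the default threefry PRNG implementation. A new-style key from jax.random.key is a 0-dimensional typed array, while jax.random.key_data unwraps it into a uint32 array with the same (2,) layout as a legacy jax.random.PRNGKey, which is why indexing with [0] works again after the conversion.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a 0-dimensional array with a key dtype.
typed_key = jax.random.key(0)
# Legacy key: a plain uint32 array of shape (2,).
legacy_key = jax.random.PRNGKey(0)

# key_data unwraps the typed key into its raw uint32 key data, which
# (assuming the default threefry implementation) matches the legacy layout.
raw = jax.random.key_data(typed_key)

print(typed_key.shape)                   # ()
print(legacy_key.shape)                  # (2,)
print(raw.shape, raw.dtype)              # (2,) uint32
print(bool(jnp.all(raw == legacy_key)))  # True
```

The trade-off is that the converted value is no longer a typed key, so any JAX random call that later receives it is back on the legacy uint32 path.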
Hi Isaac, I think what we want to do instead is change all the places where we index into the rng so that they just use the array.
E.g. from the above traceback you'll notice in our code:
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
We want to change data_rng[0] to just data_rng, since migrating to random.key will return 0-dimensional arrays.
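A small JAX-only sketch of that point, assuming a typed key created with jax.random.key (the name data_rng here is just a stand-in for the workload's variable). It shows that the key itself has no element 0 to index, that it can be passed whole to JAX random functions, and that indexing remains valid on a batch of keys such as the output of jax.random.split. It does not reproduce the tf.data shuffle call from the traceback.

```python
import jax

# Stand-in for the workload's data_rng after migrating to jax.random.key.
data_rng = jax.random.key(0)
# A typed key is 0-dimensional, so data_rng[0] has nothing to index.
assert data_rng.ndim == 0

# Instead of data_rng[0], pass the key itself to JAX random functions.
noise = jax.random.uniform(data_rng, (3,))

# Indexing still makes sense on a batch of keys, e.g. the output of split.
keys = jax.random.split(data_rng, 2)  # shape (2,): one typed key per element
first_key = keys[0]                   # a single 0-d typed key
```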
I would do a search for rng[0] to find the pieces of code to correct.
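One way to run that search, sketched in Python as an alternative to grep. The helpers scan_source and scan_tree are hypothetical names, and the regex simply matches any identifier ending in rng followed by [0] (data_rng[0], eval_rng[0], rng[0], and so on).

```python
import pathlib
import re
import sys

# Matches data_rng[0], eval_rng[0], rng[0], ... anywhere on a line.
RNG_INDEX = re.compile(r"rng\[0\]")


def scan_source(text):
    """Return (line_number, stripped_line) pairs that index an rng with [0]."""
    return [
        (lineno, line.strip())
        for lineno, line in enumerate(text.splitlines(), start=1)
        if RNG_INDEX.search(line)
    ]


def scan_tree(root="."):
    """Yield (path, line_number, line) for every match in *.py files under root."""
    for path in sorted(pathlib.Path(root).rglob("*.py")):
        for lineno, line in scan_source(path.read_text(encoding="utf-8")):
            yield str(path), lineno, line


if __name__ == "__main__":
    # Usage (hypothetical script name): python find_rng_index.py algorithmic_efficiency
    for path, lineno, line in scan_tree(sys.argv[1] if len(sys.argv) > 1 else "."):
        print(f"{path}:{lineno}: {line}")
```

A plain text match like this will also flag comments and strings, so each hit still needs a quick manual check before editing.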
System info: Ubuntu 20.04, Python 3.11, NVIDIA RTX 3080 Ti
JAX versions: jax==0.4.35, jax-cuda12-pjrt==0.4.35, jax-cuda12-plugin==0.4.35, jaxlib==0.4.35
I'm getting the following error while migrating from jax.random.PRNGKey to jax.random.key.
Here is the full traceback: