mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

IndexError: Too many indices: 0-dimensional array indexed with 1 regular index, while migrating from jax.random.PRNGKey to jax.random.key #815

Open init-22 opened 4 days ago

init-22 commented 4 days ago

System Info: Ubuntu 20.04, Python 3.11, NVIDIA 3080 Ti

Jax Versions: jax==0.4.35 jax-cuda12-pjrt==0.4.35 jax-cuda12-plugin==0.4.35 jaxlib==0.4.35

Getting the following error while migrating from PRNGKey to key.

Here is the full traceback:

Traceback (most recent call last):
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
    app.run(main)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 221, in train_once
    input_queue = workload._build_input_queue(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 155, in _build_input_queue
    ds = _build_mnist_dataset(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
    ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
                                                 ~~~~~~~~^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 646, in _getitem
    return lax_numpy._rewriting_take(self, item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11411, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11420, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11528, in _index_to_gather
    idx = _canonicalize_tuple_index(len(x_shape), idx)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11852, in _canonicalize_tuple_index
    raise IndexError(
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.

init-22 commented 2 days ago

@priyakasimbeg

I can resolve this with jax.random.key_data(key). I added the following lines to submission_runner.py: data_rng = jax.random.key_data(data_rng) at line 213 and eval_rng = jax.random.key_data(eval_rng) at line 339.

I did not encounter the error with the other keys, so maybe they don't need to be changed. Can you please tell me if this seems okay?
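
For reference, here is a minimal sketch of the difference between the two key styles, assuming the default threefry PRNG implementation and no custom PRNG flags. The variable names (legacy, typed, raw) are illustrative and do not come from the repository.

```python
import jax

# Legacy-style key: a plain uint32 array, shape (2,) under the default
# threefry implementation (assumption: jax_default_prng_impl is unchanged).
legacy = jax.random.PRNGKey(0)

# New-style typed key: a scalar key array with shape (), so it cannot be indexed.
typed = jax.random.key(0)

print("legacy:", legacy.shape, legacy.dtype)
print("typed: ", typed.shape)

try:
    typed[0]  # same kind of access as data_rng[0] in the traceback
except IndexError as err:
    print("indexing a typed key fails:", err)

# key_data exposes the raw uint32 words behind a typed key, which can be
# indexed the same way the legacy key was.
raw = jax.random.key_data(typed)
print("raw:   ", raw.shape, raw.dtype)
seed_word = raw[0]
```

Under these assumptions, raw holds the same words as the legacy key built from the same seed, so raw[0] should reproduce the value that data_rng[0] used to supply.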

priyakasimbeg commented 6 hours ago

Hi Isaac, I think what we want to do instead is change all the places where we index into the rng to just use the array.

E.g. from the above traceback you'll notice in our code:

 File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
    ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])

We want to change data_rng[0] to just data_rng, since after migrating to random.key the keys are 0-dimensional arrays. I would search for rng[0] to find the pieces of code to correct.
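
The search suggested above could be sketched roughly as follows, using only the standard library. RNG_INDEX, scan_text, and scan_tree are hypothetical helper names, and the pattern only catches names ending in "rng" that are indexed with a literal 0, so other indexing styles would need a broader pattern.

```python
import re
from pathlib import Path

# Hypothetical pattern: any identifier ending in "rng" indexed with a literal 0,
# e.g. data_rng[0], eval_rng[0], rng[0].
RNG_INDEX = re.compile(r"\b\w*rng\[0\]")


def scan_text(text):
    """Return (line_number, stripped_line) for each line that indexes an rng with [0]."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if RNG_INDEX.search(line):
            hits.append((lineno, line.strip()))
    return hits


def scan_tree(root):
    """Apply scan_text to every .py file under root and return (path, line_number, line)."""
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in scan_text(text):
            hits.append((str(path), lineno, line))
    return hits
```

Calling scan_tree on the repository root would then list each candidate line to review by hand. Whether a given call site can take the key directly or still needs raw integer data is something to check case by case.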