mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32 #801

Open init-22 opened 1 month ago

init-22 commented 1 month ago

I was trying to run submission_runner.py inside the Docker container and got a TypeError. Use these commands to reproduce the error:

sudo docker run -it   -v <PATH>/algorithmic-efficiency:/algorithmic-efficiency --runtime=nvidia algoperf_pytorch /bin/bash

cd algorithmic-efficiency

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json

Here is the traceback:

Traceback (most recent call last):
  File "submission_runner.py", line 714, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 682, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 130, in update_params
    per_device_rngs = jax.random.split(rng, jax.local_device_count())
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 217, in split
    key, wrapped = _check_prng_key(key)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
    return prng.random_wrap(key, impl=default_prng_impl()), True
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 907, in random_wrap
    _check_prng_key_data(impl, base_arr)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 119, in _check_prng_key_data
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32

Am I missing something?
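The last frames of the traceback can be reproduced in isolation. The sketch below is only an illustration of the dtype check, not a diagnosis of this issue: `bad_key` is a hypothetical stand-in for whatever `rng` value reached `update_params`, and `try_split` is a helper invented for this example. Note that the command above passes `--framework=pytorch` together with a submission under `adamw/jax/`; the thread does not confirm whether that combination is what produces the int32 key.

```python
import jax
import numpy as np

# A raw key created by JAX itself: with the default PRNG implementation
# this is a uint32 array of shape (2,).
good_key = jax.random.PRNGKey(0)

# Hypothetical stand-in for a key built outside JAX with a signed dtype.
bad_key = np.array([0, 42], dtype=np.int32)


def try_split(key, num):
    """Return (subkeys, None) on success or (None, message) on TypeError."""
    try:
        return jax.random.split(key, num), None
    except TypeError as err:
        return None, str(err)


subkeys, err = try_split(good_key, 2)          # expected to succeed
bad_subkeys, bad_err = try_split(bad_key, 2)   # expected to hit the dtype check

# Reinterpreting the same bits as uint32 yields data that passes the check.
# This shows what JAX accepts; it is not a recommended fix for the issue.
fixed_key = bad_key.view(np.uint32)
fixed_subkeys, fixed_err = try_split(fixed_key, 2)

print("good key dtype:", good_key.dtype)
print("error for int32 key:", bad_err)
```

Printing `rng.dtype` just before the `jax.random.split` call in `submission.py` would show whether the key already arrives as int32.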

priyakasimbeg commented 4 weeks ago

I noticed you ran with a Docker container called algoperf_pytorch. Does it have the correct JAX dependencies installed? Could you attach the output of pip freeze from the container environment?
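As a lighter-weight complement to a full pip freeze, the installed versions of a few packages can be queried from Python directly. This is a minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_versions` and the package list are assumptions chosen for illustration, since the thread does not say which packages matter.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Package names are illustrative, not taken from the thread.
report = installed_versions(["jax", "jaxlib", "torch"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Running this inside the container shows at a glance whether JAX is present at all and which version is picked up.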

init-22 commented 4 weeks ago

Oh yeah, I created one for PyTorch and just rebuilt it with both, but got the same error. Please check this: pip_deps.txt