mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

FastMRI JAX Dropout broken #664

Closed by priyakasimbeg 6 months ago

priyakasimbeg commented 6 months ago

Description

The FastMRI JAX workload fails with the following error:

flax.errors.InvalidRngError: Dropout_0 needs PRNG for "dropout" (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)

Logs: https://gist.github.com/priyakasimbeg/e2ecdd85340994fc15e01738dc7d8d2c

Steps to Reproduce

The final hparam point for the external prize_qualification_baselines has dropout set to 0.1. To reproduce, either run only the last point by setting hparam_start_index=4:

python3 submission_runner.py \
  --framework=jax \
  --workload=fastmri \
  --submission_path=prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py \
  --tuning_search_space=prize_qualification_baselines/external_tuning/tuning_search_space.json \
  --data_dir=/data/fastmri \
  --experiment_dir=/experiment_runs \
  --experiment_name=prize_qualification/study_4 \
  --overwrite=true \
  --save_checkpoints=false \
  --num_tuning_trials=5 \
  --hparam_start_index=4 \
  --rng_seed=813120851 \
  --max_global_steps=36189

Or remove the other hparam points from tuning_search_space.json and set --num_tuning_trials to 1.

priyakasimbeg commented 6 months ago

Fixed in https://github.com/mlcommons/algorithmic-efficiency/pull/672