mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Deepspeech norm and spec aug variant JAX: KeyError: 'batch_stats' #653

Closed: priyakasimbeg closed this issue 6 months ago.

priyakasimbeg commented 6 months ago

The JAX version of the DeepSpeech norm and SpecAugment variant (workload librispeech_deepspeech_norm_and_spec_aug) fails with KeyError: 'batch_stats'.
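One plausible mechanism is offered here as an unconfirmed assumption, not as the diagnosed root cause. In Flax, module.init returns a dict-like mapping of variable collections, and a 'batch_stats' collection exists only if the model contains layers (such as nn.BatchNorm) that keep running statistics. If this variant's model produces no such collection while the workload code indexes variables['batch_stats'] directly, Python raises exactly KeyError: 'batch_stats'. The sketch uses plain dicts as stand-ins for Flax variables; split_variables and all example values are hypothetical and are not code from the repository.

```python
# Hypothetical sketch: plain dicts stand in for Flax variable collections.
# Assumption (not confirmed by the issue): the failing variant's model has
# no normalization layers with running statistics, so init yields no
# 'batch_stats' collection.
from typing import Any, Dict, Tuple

Variables = Dict[str, Any]


def split_variables(variables: Variables) -> Tuple[Any, Any]:
  """Return (params, model_state), tolerating a missing 'batch_stats'."""
  params = variables['params']
  # variables['batch_stats'] would raise KeyError: 'batch_stats' when the
  # collection is absent; .get() falls back to an empty state instead.
  model_state = variables.get('batch_stats', {})
  return params, model_state


# Stand-ins for the output of module.init(...) (hypothetical values).
with_batch_norm = {
    'params': {'dense': {'kernel': [1.0]}},
    'batch_stats': {'bn': {'mean': [0.0], 'var': [1.0]}},
}
without_batch_norm = {'params': {'dense': {'kernel': [1.0]}}}

# Direct indexing reproduces the error message reported in this issue.
try:
  without_batch_norm['batch_stats']
except KeyError as err:
  print('direct indexing fails:', repr(err))

print(split_variables(with_batch_norm)[1])
print(split_variables(without_batch_norm)[1])
```

Falling back to an empty state keeps the (params, model_state) pair shape uniform across models with and without BatchNorm, though whether the actual fix took this form is not stated in the issue.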

Description

Logs: link

Steps to Reproduce

python3 submission_runner.py \
  --framework=jax \
  --workload=librispeech_deepspeech_norm_and_spec_aug \
  --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \
  --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json \
  --data_dir=/data/librispeech \
  --num_tuning_trials=1 \
  --experiment_dir=/experiment_runs \
  --experiment_name=variants_target_setting/study_0 \
  --overwrite=true \
  --save_checkpoints=false \
  --rng_seed=1692305324 \
  --max_global_steps=48000 \
  --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab \
  2>&1 | tee -a /logs/librispeech_deepspeech_norm_and_spec_aug_jax_02-17-2024-17-22-51.log