mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Criteo JAX resnet variant breaks with TypeError #662

Closed by priyakasimbeg 8 months ago

priyakasimbeg commented 8 months ago

Steps to Reproduce

python3 submission_runner.py --framework=jax --workload=criteo1tb_resnet --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json --data_dir=/data/criteo1tb --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=variants_target_setting/study_0 --overwrite=true --save_checkpoints=false --rng_seed=1884662704 --max_global_steps=10666

Description

Log message:

Traceback (most recent call last):
  File "submission_runner.py", line 689, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 657, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 568, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 221, in train_once
    model_params, model_state = workload.init_model_fn(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py", line 102, in init_model_fn
    initial_variables = jax.jit(init_fn)(
  File "/algorithmic-efficiency/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py", line 91, in __call__
    if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
TypeError: '>' not supported between instances of 'NoneType' and 'float'

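This looks related to the tuning hparam config: `dropout_rate` reaches the model as `None`, so the comparison `self.dropout_rate > 0.0` in `models.py` raises the TypeError.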
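For illustration, here is a minimal sketch of the failing comparison and one possible None-safe guard. The helper names `should_apply_dropout` and `reproduce_error` are hypothetical and not part of the repository. Treating a missing rate as 0.0 (dropout disabled) is an assumption, not necessarily the fix that was applied.

```python
from typing import Optional


def should_apply_dropout(dropout_rate: Optional[float],
                         layer_idx: int,
                         num_layers_top: int) -> bool:
    """Hypothetical None-safe variant of the check at models.py line 91."""
    # Assumption: a missing (None) rate means dropout is disabled.
    rate = 0.0 if dropout_rate is None else dropout_rate
    return rate > 0.0 and layer_idx == num_layers_top - 2


def reproduce_error() -> str:
    """Runs the unguarded comparison with None and returns the error text."""
    dropout_rate = None
    try:
        _ = dropout_rate > 0.0  # same comparison as in the traceback
    except TypeError as err:
        return str(err)
    return ""


if __name__ == "__main__":
    print(reproduce_error())
    print(should_apply_dropout(None, layer_idx=3, num_layers_top=5))
    print(should_apply_dropout(0.1, layer_idx=3, num_layers_top=5))
```

An alternative would be to make sure the tuning search space or the workload always supplies a numeric `dropout_rate`, so that `None` never reaches the model.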