mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

MNIST fails to run on main branch #62

Closed danielsnider closed 2 years ago

danielsnider commented 2 years ago

Recent merges have introduced bugs that prevent the MNIST workload from running (both JAX and PyTorch).

Description

The bugs were introduced sometime after this commit (which works):

commit 6f8c6bdc8f4055659d7f7e75a512152dbbd35daa (upstream/main)
Merge: cc58b23 489df83
Author: Frank <frank.stefan.schneider@gmail.com>
Date:   Wed Feb 23 11:10:59 2022 +0100

    Merge pull request #57 from fsschneider/rules_update_dependencies

    Add suggested rule changes RE dependencies

Steps to Reproduce

python3 submission_runner.py --framework=pytorch --workload=mnist_pytorch --submission_path=baselines/mnist/mnist_pytorch/submission.py --tuning_search_space=baselines/mnist/tuning_search_space.json

or

python3 submission_runner.py --framework=jax --workload=mnist_jax --submission_path=baselines/mnist/mnist_jax/submission.py --tuning_search_space=baselines/mnist/tuning_search_space.json

Source or Possible Fix

I found three bugs, but there are possibly more.

Bug 1. README.md is out of date.

The script formerly at algorithmic_efficiency/submission_runner.py is now located at submission_runner.py, but README.md still references the old path.

Bug 2. random_utils import issue

$ python3 submission_runner.py --framework=jax --workload=mnist_jax --submission_path=baselines/mnist/mnist_jax/submission.py --tuning_search_space=baselines/mnist/tuning_search_space.json

Traceback (most recent call last):
  File "submission_runner.py", line 317, in <module>
    app.run(main)
  File "/root/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 302, in main
    workload = _import_workload(
  File "submission_runner.py", line 132, in _import_workload
    workload_module = importlib.import_module(workload_path)
  File "/usr/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 848, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py", line 12, in <module>
    from algorithmic_efficiency.workloads.mnist.workload import Mnist
  File "/home/ubuntu/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 1, in <module>
    import random_utils as prng
ModuleNotFoundError: No module named 'random_utils'

Possible solution:

-import random_utils as prng
+from algorithmic_efficiency import random_utils as prng
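
For anyone wondering why the one-line change above helps: a bare `import random_utils` only resolves if the directory that contains random_utils.py is itself on `sys.path`, whereas the package-qualified form resolves whenever the package's parent directory is. The sketch below reproduces this with a throwaway package; `demo_pkg`, `bare_import`, and `qualified_import` are made-up names for illustration, not files from this repository.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def build_demo_package(root: Path) -> None:
    """Create a tiny package that mirrors the layout in the traceback (hypothetical names)."""
    pkg = root / "demo_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (pkg / "random_utils.py").write_text("SEED = 0\n")
    # Mirrors the failing line: a bare import of a module that lives inside the package.
    (pkg / "bare_import.py").write_text("import random_utils as prng\n")
    # Mirrors the proposed fix: import the module through its package.
    (pkg / "qualified_import.py").write_text(
        "from demo_pkg import random_utils as prng\n")


def try_import(module_name: str) -> str:
    """Return 'ok' if the module imports, otherwise the error text."""
    try:
        importlib.import_module(module_name)
        return "ok"
    except ModuleNotFoundError as err:
        return f"ModuleNotFoundError: {err}"


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    build_demo_package(root)
    # Only the package's parent directory is importable, as when running from the repo root.
    sys.path.insert(0, str(root))
    importlib.invalidate_caches()
    try:
        results = {
            name: try_import(name)
            for name in ("demo_pkg.bare_import", "demo_pkg.qualified_import")
        }
    finally:
        sys.path.remove(str(root))

for name, outcome in results.items():
    print(f"{name}: {outcome}")
```

The bare import fails the same way as in the traceback, while the qualified import succeeds, assuming no unrelated `random_utils` module happens to be installed in the environment.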

Bug 3. Can't instantiate abstract class MnistWorkload

# python3 submission_runner.py --framework=pytorch --workload=mnist_pytorch --submission_path=baselines/mnist/mnist_pytorch/submission.py --tuning_search_space=baselines/mnist/tuning_search_space.json
Traceback (most recent call last):
  File "submission_runner.py", line 317, in <module>
    app.run(main)
  File "/root/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 302, in main
    workload = _import_workload(
  File "submission_runner.py", line 145, in _import_workload
    return workload_class()
TypeError: Can't instantiate abstract class MnistWorkload with abstract methods num_eval_train_examples, num_validation_examples, param_shapes

I haven't looked closely at how to fix this. If whoever introduced the issue recognizes it, maybe they can take a look. Otherwise I can, but I'm too busy at the moment, sorry.
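
For reference, this TypeError is what Python's `abc` module raises whenever a subclass leaves abstract members unimplemented. Below is a minimal sketch of the mechanism, not the repository's actual code: the class names are hypothetical, only the three member names are taken from the traceback, and the returned values are placeholders.

```python
import abc


class Workload(abc.ABC):
    """Hypothetical stand-in for a base class that declares abstract properties."""

    @property
    @abc.abstractmethod
    def num_eval_train_examples(self):
        """Number of training examples used for evaluation."""

    @property
    @abc.abstractmethod
    def num_validation_examples(self):
        """Number of validation examples."""

    @property
    @abc.abstractmethod
    def param_shapes(self):
        """Shapes of the model parameters."""


class IncompleteMnistWorkload(Workload):
    """Overrides nothing, so it cannot be instantiated."""


class PatchedMnistWorkload(Workload):
    """Overrides every abstract property; the values are placeholders, not real settings."""

    @property
    def num_eval_train_examples(self):
        return 10000  # placeholder value (assumption)

    @property
    def num_validation_examples(self):
        return 10000  # placeholder value (assumption)

    @property
    def param_shapes(self):
        return {}  # placeholder: a real workload would derive this from the model


try:
    IncompleteMnistWorkload()
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
print(sorted(Workload.__abstractmethods__))
patched = PatchedMnistWorkload()  # instantiates without error
```

The exact wording of the message differs between Python versions, but the cause is the same: every name listed in `__abstractmethods__` has to be overridden somewhere in the subclass hierarchy before the class can be instantiated.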

I hope this bug report helps!

scarere commented 2 years ago

A quick hot fix to get things running

https://github.com/UofT-EcoSystem/algorithmic-efficiency/commit/34473aa0cab396bf66869abf5487505423bfdb81

fsschneider commented 2 years ago

I looked at it and fixed the bugs described here in PR #65.

This should, however, be seen as a hotfix, not as a proper solution (and in its current form, the MNIST example is still not fully functional, see below).

We currently don't consistently use the num_X_examples property for evaluating the model. I would say, ideally, all models should follow the same protocol during eval_model, e.g. evaluate on the train set, validation set, and test set. This would use the num_X_examples property, but currently, we are not doing this for most models.

Similarly, we wanted to provide both the param_shapes and the model_params_types as properties for submitters, but in most workloads, they are not implemented. Ideally, they would be defined in the framework-agnostic base class of each workload.
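
To make the idea in the two paragraphs above more concrete, here is a rough sketch of how a framework-agnostic base class could own the `num_X_examples` properties and drive one shared evaluation protocol over the train, validation, and test splits. All names here (`BaseMnistWorkload`, `_eval_split`, `DummyWorkload`, `num_test_examples`, the metric keys) and all split sizes are assumptions for illustration; this is not what PR #65 implements.

```python
import abc
from typing import Dict


class BaseMnistWorkload(abc.ABC):
    """Hypothetical framework-agnostic base class; not the repository's actual class."""

    # Split sizes below are illustrative assumptions, not the benchmark's settings.
    @property
    def num_eval_train_examples(self) -> int:
        return 10000

    @property
    def num_validation_examples(self) -> int:
        return 10000

    @property
    def num_test_examples(self) -> int:
        return 10000

    @abc.abstractmethod
    def _eval_split(self, split: str, num_examples: int) -> float:
        """Framework-specific hook: return a metric over `num_examples` of `split`."""

    def eval_model(self) -> Dict[str, float]:
        """Shared protocol: evaluate on every split using the size properties."""
        sizes = {
            "eval_train": self.num_eval_train_examples,
            "validation": self.num_validation_examples,
            "test": self.num_test_examples,
        }
        return {
            f"{split}/accuracy": self._eval_split(split, num_examples)
            for split, num_examples in sizes.items()
        }


class DummyWorkload(BaseMnistWorkload):
    """Stand-in for a JAX or PyTorch subclass; records calls instead of running a model."""

    def __init__(self) -> None:
        self.calls = []

    def _eval_split(self, split: str, num_examples: int) -> float:
        self.calls.append((split, num_examples))
        return 0.0  # a real subclass would run its evaluation loop here


workload = DummyWorkload()
metrics = workload.eval_model()
print(metrics)
```

With this layout, the JAX and PyTorch subclasses would only implement the per-split hook, so every workload would report the same set of metrics.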

For further details, see PR #65.

runame commented 2 years ago

See PR #71.