google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Imagenet example does not work with newer flax versions #3364

Open gkroiz opened 12 months ago

gkroiz commented 12 months ago

With newer versions of Flax, there seems to be an outdated API call in the ImageNet test.

First, when testing with the most recent stable releases, jax 0.4.16 and flax 0.7.4, there was an orbax issue (most likely because orbax has not had a stable release since the end of July).

Full error message (on TPU):

Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
2023-09-21 23:18:03.034775: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.053484: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.063527: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.110393: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.582526: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.582586: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.582593: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-21 23:18:03.602052: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.602115: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.602121: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-21 23:18:03.622315: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.622377: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.622384: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-21 23:18:03.657278: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.657346: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-21 23:18:03.657353: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 32, in <module>
    import train
  File "/home/alijafari/flax/examples/imagenet/train.py", line 33, in <module>
    from flax.training import train_state
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 19, in <module>
    import optax
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/__init__.py", line 102, in <module>
    from optax._src.second_order import fisher_diag
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/second_order.py", line 46, in <module>
    v: jnp.DeviceArray,
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 32, in <module>
    import train
  File "/home/alijafari/flax/examples/imagenet/train.py", line 33, in <module>
    from flax.training import train_state
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 19, in <module>
    import optax
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/__init__.py", line 102, in <module>
    from optax._src.second_order import fisher_diag
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/second_order.py", line 46, in <module>
    v: jnp.DeviceArray,
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 32, in <module>
    import train
  File "/home/alijafari/flax/examples/imagenet/train.py", line 33, in <module>
    from flax.training import train_state
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 19, in <module>
    import optax
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/__init__.py", line 102, in <module>
    from optax._src.second_order import fisher_diag
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/second_order.py", line 46, in <module>
    v: jnp.DeviceArray,
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 32, in <module>
    import train
  File "/home/alijafari/flax/examples/imagenet/train.py", line 33, in <module>
    from flax.training import train_state
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/train_state.py", line 19, in <module>
    import optax
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/__init__.py", line 102, in <module>
    from optax._src.second_order import fisher_diag
  File "/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/second_order.py", line 46, in <module>
    v: jnp.DeviceArray,
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
##### Command execution on worker 0 failed with exit status 1. Continuing.
##### Command execution on worker 1 failed with exit status 1. Continuing.
##### Command execution on worker 2 failed with exit status 1. Continuing.
##### Command execution on worker 3 failed with exit status 1. Continuing
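
Each worker's traceback above ends with optax referencing `jnp.DeviceArray`, an attribute the installed jax no longer exposes. A minimal sketch for probing whether an installed module still provides a given attribute, without importing the whole training script, might look like the following. The helper name `module_has_attribute` is hypothetical and is not part of jax, flax, or optax.

```python
import importlib


def module_has_attribute(module_name, attribute):
    """Return True or False if the module imports, or None if it cannot be imported.

    Hypothetical helper for illustration; not part of any library.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None  # the module itself is not installed
    try:
        getattr(module, attribute)
    except AttributeError:
        return False  # module imports, but the attribute is gone
    return True


if __name__ == "__main__":
    # A result of False here would be consistent with the AttributeError in the log.
    print(module_has_attribute("jax.numpy", "DeviceArray"))
```

Running this on each worker before launching the benchmark should show whether the installed jax and optax versions are compatible on this point.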

We then tested stable jax and flax releases that align with the most recent orbax release: jax 0.4.13 and flax 0.7.0. However, this ran into a new error:

How to reproduce:

pip install "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.0
cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py
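
Since the failure depends on which jax, flax, optax, and orbax versions end up installed together, it can help to record the resolved versions on each worker after the steps above. The following is a minimal sketch using only the standard library. The package list is an assumption: `jaxlib` and the distribution name `orbax-checkpoint` are not mentioned in this report and may need adjusting.

```python
from importlib import metadata

# Assumed distribution names; "jaxlib" and "orbax-checkpoint" are guesses
# and are not taken from the report above.
PACKAGES = ("jax", "jaxlib", "flax", "optax", "orbax-checkpoint")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output alongside the error log makes it easier to see whether `pip install -r requirements.txt` pulled in a version of optax or orbax that differs from the one intended.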

Full error message (on TPU):

Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
2023-09-22 00:02:14.684726: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-22 00:02:14.714459: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-22 00:02:14.746449: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-22 00:02:14.755915: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.228367: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.228429: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.228435: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-22 00:02:15.260470: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.260531: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.260537: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-22 00:02:15.304349: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.304411: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.304417: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-22 00:02:15.305893: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.305955: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-22 00:02:15.305961: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-22 00:02:16.346399: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-22 00:02:16.346423: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-22 00:02:16.368624: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-22 00:02:16.368648: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-22 00:02:16.429315: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-22 00:02:16.429338: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-22 00:02:16.438815: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-22 00:02:16.438839: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
I0922 00:02:26.652038 140051359447040 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:26.654772 140051359447040 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0922 00:02:26.655005 140051359447040 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0922 00:02:26.655146 140051359447040 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:26.655780 140051359447040 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:26.781402 140051359447040 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0922 00:02:26.893633 140051359447040 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0922 00:02:27.190137 140051359447040 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:27.190793 140051359447040 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:27.212210 140051359447040 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/schedule.py:391: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
  def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray:
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 66, in <module>
    absltest.main()
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2049, in main
    _run_in_app(run_tests, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2166, in _run_in_app
    app.run(main=main_function)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2164, in main_function
    function(argv, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2568, in run_tests
    result = _run_and_get_tests_result(
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2537, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 86, in run
    return super(TextTestRunner, self).run(test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1662, in init
    _, v_out = self.init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1567, in init_with_output
    return init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/core/scope.py", line 960, in wrapper
    raise ValueError('First argument passed to an init function should be a '
jax._src.traceback_util.UnfilteredStackTrace: ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 150, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 258, in _report_benchmark_results
    raise ValueError('Unable to determine test name for reporting '
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 10.909s

FAILED (errors=2)
I0922 00:02:28.195271 139886586542080 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:28.198028 139886586542080 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0922 00:02:28.198243 139886586542080 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0922 00:02:28.198381 139886586542080 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:28.200115 139886586542080 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:28.339387 139886586542080 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0922 00:02:28.452126 139886586542080 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0922 00:02:28.746239 139886586542080 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:28.746915 139886586542080 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:28.768107 139886586542080 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/schedule.py:391: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
  def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray:
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 66, in <module>
    absltest.main()
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2049, in main
    _run_in_app(run_tests, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2166, in _run_in_app
    app.run(main=main_function)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2164, in main_function
    function(argv, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2568, in run_tests
    result = _run_and_get_tests_result(
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2537, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 86, in run
    return super(TextTestRunner, self).run(test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1662, in init
    _, v_out = self.init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1567, in init_with_output
    return init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/core/scope.py", line 960, in wrapper
    raise ValueError('First argument passed to an init function should be a '
jax._src.traceback_util.UnfilteredStackTrace: ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 150, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 258, in _report_benchmark_results
    raise ValueError('Unable to determine test name for reporting '
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 12.547s

FAILED (errors=2)
I0922 00:02:28.993444 139703183525888 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:28.996214 139703183525888 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0922 00:02:28.996427 139703183525888 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0922 00:02:28.996560 139703183525888 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:28.998337 139703183525888 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:29.142283 139703183525888 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0922 00:02:29.256080 139703183525888 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0922 00:02:29.553791 139703183525888 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:29.554453 139703183525888 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:29.559327 139777785620480 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:29.562028 139777785620480 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0922 00:02:29.562225 139777785620480 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0922 00:02:29.562349 139777785620480 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:29.563956 139777785620480 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:29.575566 139703183525888 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/schedule.py:391: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
  def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray:
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 66, in <module>
    absltest.main()
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2049, in main
    _run_in_app(run_tests, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2166, in _run_in_app
    app.run(main=main_function)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2164, in main_function
    function(argv, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2568, in run_tests
    result = _run_and_get_tests_result(
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2537, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 86, in run
    return super(TextTestRunner, self).run(test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1662, in init
    _, v_out = self.init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1567, in init_with_output
    return init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/core/scope.py", line 960, in wrapper
    raise ValueError('First argument passed to an init function should be a '
jax._src.traceback_util.UnfilteredStackTrace: ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 150, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 258, in _report_benchmark_results
    raise ValueError('Unable to determine test name for reporting '
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 13.262s

FAILED (errors=2)
W0922 00:02:29.722676 139777785620480 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0922 00:02:29.835337 139777785620480 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0922 00:02:30.127933 139777785620480 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0922 00:02:30.128622 139777785620480 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0922 00:02:30.149866 139777785620480 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
/home/alijafari/.local/lib/python3.10/site-packages/optax/_src/schedule.py:391: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
  def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray:
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 66, in <module>
    absltest.main()
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2049, in main
    _run_in_app(run_tests, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2166, in _run_in_app
    app.run(main=main_function)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2164, in main_function
    function(argv, args, kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2568, in run_tests
    result = _run_and_get_tests_result(
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/absltest.py", line 2537, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/alijafari/.local/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 86, in run
    return super(TextTestRunner, self).run(test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1662, in init
    _, v_out = self.init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1567, in init_with_output
    return init_with_output(
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/core/scope.py", line 960, in wrapper
    raise ValueError('First argument passed to an init function should be a '
jax._src.traceback_util.UnfilteredStackTrace: ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 361, in train_and_evaluate
    state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 258, in create_train_state
    params, batch_stats = initialized(rng, image_size, model)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 69, in initialized
    variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  File "/home/alijafari/flax/examples/imagenet/train.py", line 67, in init
    return model.init(*args)
ValueError: First argument passed to an init function should be a `jax.PRNGKey` or a dictionary mapping strings to `jax.PRNGKey`.

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 150, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 258, in _report_benchmark_results
    raise ValueError('Unable to determine test name for reporting '
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 13.907s

FAILED (errors=2)
##### Command execution on worker 2 failed with exit status 1. Continuing.
##### Command execution on worker 1 failed with exit status 1. Continuing.
##### Command execution on worker 3 failed with exit status 1. Continuing.
##### Command execution on worker 0 failed with exit status 1. Continuing.

System information

cgarciae commented 12 months ago

I cannot reproduce this on a single TPU node (v4-8). What setup are you using?

gkroiz commented 12 months ago

@cgarciae, here is the setup:

# Install newest version of JAX and jaxlib
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

# Clone the ImageNet model and install the corresponding requirements:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'

# To generate fake data, the model needs information on the dimensions of the dataset. This can be gathered from the ImageNet dataset's metadata:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'

# Train the model
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'

cgarciae commented 12 months ago

I see. This is multi-host right? Sadly I cannot test this.

andsteing commented 12 months ago

How exactly did you set up your VM?

I tried the following:

PROJECT_ID=...
TPU_NAME=isssue_3364
ZONE=us-central1-a
gcloud alpha compute tpus tpu-vm create $TPU_NAME \
    --zone=$ZONE \
    --version v2-alpha --accelerator-type v2-32

But then on the first command

# Install newest version of JAX and jaxlib
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

I get the following error:

ERROR: Could not find a version that satisfies the requirement jax[tpu]==0.4.16 (from versions: 0.0, 0.1, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.1.10, 0.1.11, 0.1.12, 0.1.13, 0.1.14, 0.1.15, 0.1.16, 0.1.18, 0.1.19, 0.1.20, 0.1.21, 0.1.22, 0.1.23, 0.1.24, 0.1.25, 0.1.26, 0.1.27, 0.1.28, 0.1.29, 0.1.30, 0.1.31, 0.1.32, 0.1.33, 0.1.34, 0.1.35, 0.1.36, 0.1.37, 0.1.38, 0.1.39, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.45, 0.1.46, 0.1.47, 0.1.48, 0.1.49, 0.1.50, 0.1.51, 0.1.52, 0.1.53, 0.1.54, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.1.77, 0.2.0, 0.2.1, 0.2.2, 0.2.3, 0.2.4, 0.2.5, 0.2.6, 0.2.7, 0.2.8, 0.2.9, 0.2.10, 0.2.11, 0.2.12, 0.2.13, 0.2.14, 0.2.15, 0.2.16, 0.2.17, 0.2.18, 0.2.19, 0.2.20, 0.2.21, 0.2.22, 0.2.23, 0.2.24, 0.2.25, 0.2.26, 0.2.27, 0.2.28, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17, 0.3.18, 0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13)
ERROR: No matching distribution found for jax[tpu]==0.4.16
ERROR: Ignored the following versions that require a different python version: 0.4.14 Requires-Python >=3.9; 0.4.15 Requires-Python >=3.9; 0.4.16 Requires-Python >=3.9

gkroiz commented 12 months ago

@andsteing the setup for TPU v5e is a bit different from v2. Could you try with the flag --tpu-ubuntu2204-base? It should have a newer Python version.
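
The pip error above reports that jax 0.4.14 through 0.4.16 require Python >=3.9, so checking the interpreter on each worker can confirm whether the VM image is the cause. A minimal sketch, where `meets_requirement` and `JAX_MIN_PYTHON` are hypothetical names and the (3, 9) floor is taken only from that pip message:

```python
import sys

# Minimum Python version that pip reported for jax 0.4.14-0.4.16 in the error above.
JAX_MIN_PYTHON = (3, 9)


def meets_requirement(version_info=None, minimum=JAX_MIN_PYTHON):
    """Return True if the (major, minor, ...) tuple is at least `minimum`."""
    if version_info is None:
        version_info = sys.version_info
    # Compare only major and minor; the patch level does not matter here.
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    status = "ok" if meets_requirement() else "too old for jax 0.4.16"
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {status}")
```

Running this script through the same `--worker=all` ssh command used for the other steps would show the result for every host.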

andsteing commented 12 months ago

@gkroiz what was the exact command you used to start the TPU VMs?

Ah, that's probably also why @cgarciae could not reproduce the issue.

@cgarciae can you try with that image and see if you can reproduce the problem?

gkroiz commented 12 months ago

@andsteing


export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5e-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export VALID_DURATION=1d

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--accelerator-type ${ACCELERATOR_TYPE} \
--runtime-version ${RUNTIME_VERSION} \
--valid-until-duration ${VALID_DURATION} \
--service-account ${SERVICE_ACCOUNT} \
--reserved
andsteing commented 12 months ago

I can confirm I can reproduce the AttributeError: module 'jax.numpy' has no attribute 'DeviceArray' on a fresh Colab with a CPU runtime (Python 3.10.12).

Setup:

!pip install jax==0.4.16 jaxlib==0.4.16
!git clone --depth=1 https://github.com/google/flax.git
!cd flax/examples/imagenet && pip install -r requirements.txt
!pip install flax==0.7.4

Command:

!cd flax/examples/imagenet && python3 imagenet_fake_data_benchmark.py

Error:

Traceback (most recent call last):
  File "/content/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 32, in <module>
    import train
  File "/content/flax/examples/imagenet/train.py", line 33, in <module>
    from flax.training import train_state
  File "/usr/local/lib/python3.10/dist-packages/flax/training/train_state.py", line 19, in <module>
    import optax
  File "/usr/local/lib/python3.10/dist-packages/optax/__init__.py", line 102, in <module>
    from optax._src.second_order import fisher_diag
  File "/usr/local/lib/python3.10/dist-packages/optax/_src/second_order.py", line 46, in <module>
    v: jnp.DeviceArray,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

That's with

>>> !pip freeze | egrep 'jax|flax'
flax==0.7.4
jax==0.4.16
jaxlib==0.4.16
andsteing commented 12 months ago

@gkroiz

So the problem seems to be that the optax version pinned in the example's requirements.txt is not compatible:

https://github.com/google/flax/blob/242f84cac883108eb1e945221c5c544bef6cbd21/examples/imagenet/requirements.txt#L9

If you run pip install flax==0.7.4 optax==0.1.7 after installing the requirements, the setup should be compatible with pip install jax==0.4.16 jaxlib==0.4.16.

Note though that if you manually change jax and flax versions, then the dependencies as specified in requirements.txt are not expected to be compatible with your setup anymore.
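
Since the versions that end up installed can differ from what requirements.txt pins, it may help to print what is really in the environment before running the example. A minimal sketch using only the standard library; the function name `installed_versions` and the default package list are assumptions chosen for this thread:

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "flax", "optax")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}")
```

This gives roughly the same information as filtering pip freeze, but it can be called from inside the training script or a test.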

andsteing commented 12 months ago

@cgarciae wdyt, should we update (some of) our examples to newer flax/jax versions?

cgarciae commented 11 months ago

@chiamp has been recently updating the examples, do you want to take care of this issue? Else, happy to add it to my list 🙂

chiamp commented 11 months ago

sure let me take a look!

gkroiz commented 11 months ago

@andsteing using optax 0.1.7 we run into a new error:

Log:

Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
2023-09-27 21:01:54.041649: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.051241: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.078029: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.118428: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.584898: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.584958: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.584964: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-27 21:01:54.603725: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.603785: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.603791: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-27 21:01:54.629032: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.629093: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.629099: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-09-27 21:01:54.663411: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.663472: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-09-27 21:01:54.663478: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-27 21:01:55.708319: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-27 21:01:55.708343: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-27 21:01:55.748354: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-27 21:01:55.748378: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-27 21:01:55.757288: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-27 21:01:55.757312: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-09-27 21:01:55.796781: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-09-27 21:01:55.796805: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695848515.973369    7237 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1695848515.973399    7237 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1695848515.973401    7237 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695848516.002654    7821 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1695848516.002680    7821 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1695848516.002682    7821 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695848516.020979    7980 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1695848516.021006    7980 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1695848516.021008    7980 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695848516.071177    6699 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1695848516.071208    6699 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1695848516.071210    6699 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
I0927 21:02:06.333447 139675962472448 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.336215 139675962472448 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.336426 139675962472448 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.336564 139675962472448 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.337304 139675962472448 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0927 21:02:06.383298 139675962472448 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I0927 21:02:06.510347 140029308213248 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.513039 140029308213248 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.513241 140029308213248 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.513373 140029308213248 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.514047 140029308213248 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.557687 140002190469120 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.560446 140002190469120 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.560657 140002190469120 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0927 21:02:06.560795 140002190469120 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.561531 140002190469120 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0927 21:02:06.591139 139675962472448 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0927 21:02:06.606846 140002190469120 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W0927 21:02:06.644832 140029308213248 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0927 21:02:06.757888 140029308213248 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0927 21:02:06.816365 140002190469120 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0927 21:02:06.884829 139675962472448 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:06.885504 139675962472448 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0927 21:02:06.906559 139675962472448 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I0927 21:02:07.024086 139773692135424 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.026850 139773692135424 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I0927 21:02:07.027071 139773692135424 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0927 21:02:07.027209 139773692135424 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.027959 139773692135424 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.052164 140029308213248 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.052843 140029308213248 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0927 21:02:07.074269 140029308213248 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W0927 21:02:07.076008 139773692135424 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I0927 21:02:07.112762 140002190469120 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.113440 140002190469120 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0927 21:02:07.134758 140002190469120 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W0927 21:02:07.290289 139773692135424 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I0927 21:02:07.593482 139773692135424 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I0927 21:02:07.594168 139773692135424 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W0927 21:02:07.615502 139773692135424 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I0927 21:02:31.140465 139675962472448 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp0_xrlhyo/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I0927 21:02:31.222093 139675962472448 train.py:378] Initial compilation, this might take some minutes...
I0927 21:02:31.228001 140029308213248 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp8697fv0l/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I0927 21:02:31.310585 140029308213248 train.py:378] Initial compilation, this might take some minutes...
I0927 21:02:31.531608 140002190469120 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp8cqa1xr7/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I0927 21:02:31.612831 140002190469120 train.py:378] Initial compilation, this might take some minutes...
I0927 21:02:31.977073 139773692135424 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp_z_goiap/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I0927 21:02:32.058211 139773692135424 train.py:378] Initial compilation, this might take some minutes...
I0927 21:03:06.568743 139675962472448 train.py:384] Initial compilation completed.
I0927 21:03:06.944715 140029308213248 train.py:384] Initial compilation completed.
I0927 21:03:07.221070 140002190469120 train.py:384] Initial compilation completed.
I0927 21:03:07.484132 139773692135424 train.py:384] Initial compilation completed.
I0927 21:03:14.630924 140029308213248 local.py:50] Created artifact [10] Profile of type ArtifactType.URL and value None.
I0927 21:04:00.030422 140013838800448 logging_writer.py:35] [100] steps_per_second=1.127150, train_accuracy=0.011674804612994194, train_learning_rate=0.05076923221349716, train_loss=6.274923324584961
I0927 21:04:00.031371 139660464608832 logging_writer.py:35] [100] steps_per_second=1.126012, train_accuracy=0.011674804612994194, train_learning_rate=0.05076923221349716, train_loss=6.274923324584961
I0927 21:04:00.031908 139758197134912 logging_writer.py:35] [100] steps_per_second=1.136707, train_accuracy=0.011674804612994194, train_learning_rate=0.05076923221349716, train_loss=6.274923324584961
I0927 21:04:00.032414 139986689906240 logging_writer.py:35] [100] steps_per_second=1.130975, train_accuracy=0.011674804612994194, train_learning_rate=0.05076923221349716, train_loss=6.274923324584961
I0927 21:04:56.350251 139986689906240 logging_writer.py:35] [200] steps_per_second=1.775658, train_accuracy=0.051142577081918716, train_learning_rate=0.15333333611488342, train_loss=5.4565749168396
I0927 21:04:56.350302 139660464608832 logging_writer.py:35] [200] steps_per_second=1.775626, train_accuracy=0.051142577081918716, train_learning_rate=0.15333333611488342, train_loss=5.4565749168396
I0927 21:04:56.350344 140013838800448 logging_writer.py:35] [200] steps_per_second=1.775663, train_accuracy=0.051142577081918716, train_learning_rate=0.15333333611488342, train_loss=5.4565749168396
I0927 21:04:56.350569 139758197134912 logging_writer.py:35] [200] steps_per_second=1.775630, train_accuracy=0.051142577081918716, train_learning_rate=0.15333333611488342, train_loss=5.4565749168396
I0927 21:05:52.401964 139660464608832 logging_writer.py:35] [300] steps_per_second=1.784069, train_accuracy=0.13822509348392487, train_learning_rate=0.2558974325656891, train_loss=4.702647686004639
I0927 21:05:52.402337 139986689906240 logging_writer.py:35] [300] steps_per_second=1.784058, train_accuracy=0.13822509348392487, train_learning_rate=0.2558974325656891, train_loss=4.702647686004639
I0927 21:05:52.402663 139758197134912 logging_writer.py:35] [300] steps_per_second=1.784057, train_accuracy=0.13822509348392487, train_learning_rate=0.2558974325656891, train_loss=4.702647686004639
I0927 21:05:52.415851 140013838800448 logging_writer.py:35] [300] steps_per_second=1.783629, train_accuracy=0.13822509348392487, train_learning_rate=0.2558974325656891, train_loss=4.702647686004639
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 413, in train_and_evaluate
    eval_metrics = common_utils.get_metrics(eval_metrics)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 86, in get_metrics
    return stack_forest(metrics_np)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
TypeError: tree_map() missing 1 required positional argument: 'tree'

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 153, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 262, in _report_benchmark_results
    raise ValueError(
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 238.006s

FAILED (errors=2)
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 413, in train_and_evaluate
    eval_metrics = common_utils.get_metrics(eval_metrics)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 86, in get_metrics
    return stack_forest(metrics_np)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
TypeError: tree_map() missing 1 required positional argument: 'tree'

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 153, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 262, in _report_benchmark_results
    raise ValueError(
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 239.695s

FAILED (errors=2)
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 413, in train_and_evaluate
    eval_metrics = common_utils.get_metrics(eval_metrics)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 86, in get_metrics
    return stack_forest(metrics_np)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
TypeError: tree_map() missing 1 required positional argument: 'tree'

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 153, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 262, in _report_benchmark_results
    raise ValueError(
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 240.987s

FAILED (errors=2)
##### Command execution on worker 3 failed with exit status 1. Continuing.
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
[  FAILED  ] ImagenetBenchmarkFakeData.test_fake_data
======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 413, in train_and_evaluate
    eval_metrics = common_utils.get_metrics(eval_metrics)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 86, in get_metrics
    return stack_forest(metrics_np)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
TypeError: tree_map() missing 1 required positional argument: 'tree'

======================================================================
ERROR: test_fake_data (__main__.ImagenetBenchmarkFakeData)
ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 153, in tearDown
    self._report_benchmark_results()
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/testing/benchmark.py", line 262, in _report_benchmark_results
    raise ValueError(
ValueError: Unable to determine test name for reporting benchmark results. Make sure you are using `self.report_*` methods.

----------------------------------------------------------------------
Ran 1 test in 243.985s

FAILED (errors=2)
##### Command execution on worker 1 failed with exit status 1. Continuing.
##### Command execution on worker 0 failed with exit status 1. Continuing.
##### Command execution on worker 2 failed with exit status 1. Continuing.
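
The TypeError in the tracebacks above is what Python raises when a function with a required `tree` parameter is called with an empty unpacked list. That suggests the `forest` passed to `stack_forest` may have been empty, i.e. no eval metrics were collected; this is an inference from the traceback, not something the log confirms. A minimal pure-Python sketch; `tree_map` here is a stand-in with a similar signature, not the JAX function:

```python
def tree_map(f, tree, *rest):
    """Stand-in for a function whose first tree argument is required."""
    return f(tree, *rest)


def stack_forest(forest):
    # Mirrors the call shape in the traceback: the list is unpacked into
    # positional arguments, so an empty list supplies no `tree` at all.
    return tree_map(lambda *args: list(args), *forest)


def try_stack(forest):
    """Return the stacked result, or the TypeError message on failure."""
    try:
        return stack_forest(forest)
    except TypeError as err:
        return str(err)
```

Calling `try_stack([])` returns the error message instead of raising, which makes the empty-list case easy to compare against the log.
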
andsteing commented 11 months ago

@gkroiz

Can you specify the system version, Python version, Flax repo commit hash, and the output of pip freeze from one of the machines?

gkroiz commented 11 months ago

@andsteing

This is once again on v5e-16, using the following setup to create the TPU:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5e-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export VALID_DURATION=1d

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--accelerator-type ${ACCELERATOR_TYPE} \
--runtime-version ${RUNTIME_VERSION} \
--valid-until-duration ${VALID_DURATION} \
--service-account ${SERVICE_ACCOUNT} \
--reserved

Setup within TPU

@cgarciae, here is the setup:

# Install newest version of JAX and jaxlib
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

# Clone the ImageNet model and install the corresponding requirements:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'

# To generate fake data, the model needs information on the dimensions of the dataset. This can be gathered from the ImageNet dataset's metadata:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'

# Train the model
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'

Can you specify the system version, Python version, Flax repo commit hash, and the output of pip freeze from one of the machines?

I'll send a follow-up on this once I can rerun.

gkroiz commented 11 months ago

Can you specify the system version, Python version, Flax repo commit hash, and the output of pip freeze from one of the machines?

Following up here:

Python version: 3.10.12
OS: Ubuntu 22.04.2 LTS
Flax repo commit hash: 242f84cac883108eb1e945221c5c544bef6cbd21
Pip freeze output:

absl-py==1.0.0
astunparse==1.6.3
attrs==21.2.0
Automat==20.2.0
Babel==2.8.0
bcrypt==3.2.0
blinker==1.4
cachetools==5.3.1
certifi==2020.6.20
chardet==4.0.0
charset-normalizer==3.2.0
chex==0.1.7
click==8.0.3
cloud-init==23.2.1
clu==0.0.6
colorama==0.4.4
command-not-found==0.3
configobj==5.0.6
constantly==15.1.0
contextlib2==21.6.0
contourpy==1.1.1
cryptography==3.4.8
cycler==0.11.0
Cython==0.29.28
dbus-python==1.2.18
dill==0.3.7
distlib==0.3.7
distro==1.7.0
distro-info===1.1build1
dm-tree==0.1.8
etils==1.5.0
filelock==3.12.2
flatbuffers==23.5.26
flax==0.7.4
fonttools==4.42.1
fsspec==2023.9.2
future==0.18.3
gast==0.4.0
google-auth==2.23.1
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.60.0
grpcio==1.58.0
h5py==3.9.0
httplib2==0.20.2
hyperlink==21.0.0
idna==3.3
importlib-metadata==4.6.4
importlib-resources==6.1.0
incremental==21.3.0
jax==0.4.16
jaxlib==0.4.16
jeepney==0.7.1
Jinja2==3.0.3
jsonpatch==1.32
jsonpointer==2.0
jsonschema==3.2.0
keras==2.11.0
keyring==23.5.0
kiwisolver==1.4.5
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
libclang==16.0.6
libtpu-nightly==0.1.dev20230918
Markdown==3.4.4
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.0
mdurl==0.1.2
ml-collections==0.1.0
ml-dtypes==0.3.1
more-itertools==8.10.0
msgpack==1.0.6
nest-asyncio==1.5.8
netifaces==0.11.0
numpy==1.22.0
oauthlib==3.2.0
opt-einsum==3.3.0
optax==0.1.7
orbax==0.1.9
orbax-checkpoint==0.4.0
packaging==21.3
pexpect==4.8.0
Pillow==10.0.1
platformdirs==3.10.0
promise==2.3
protobuf==3.19.6
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.1
Pygments==2.16.1
PyGObject==3.42.1
PyHamcrest==2.0.2
PyJWT==2.3.0
pyOpenSSL==21.0.0
pyparsing==2.4.7
pyrsistent==0.18.1
pyserial==3.5
python-apt==2.4.0+ubuntu1
python-dateutil==2.8.2
python-debian==0.1.43+ubuntu1.1
python-magic==0.4.24
pytz==2022.1
PyYAML==5.4.1
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.5.3
rsa==4.9
scipy==1.11.2
SecretStorage==3.3.1
service-identity==18.1.0
six==1.16.0
sos==4.4
ssh-import-id==5.11
systemd-python==234
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.1
tensorflow-datasets==4.4.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-metadata==1.13.0
tensorstore==0.1.44
termcolor==2.3.0
toolz==0.12.0
tqdm==4.66.1
Twisted==22.1.0
typing_extensions==4.8.0
ubuntu-advantage-tools==8001
ufw==0.36.1
unattended-upgrades==0.1
urllib3==2.0.5
virtualenv==20.24.2
wadllib==1.3.6
Werkzeug==2.3.7
wrapt==1.15.0
zipp==1.0.0
zope.interface==5.4.0
chiamp commented 11 months ago

Can you specify the system version, Python version, Flax repo commit hash, and the output of pip freeze from one of the machines?

Following up here:

Python version: 3.10.12
OS: Ubuntu 22.04.2 LTS
Flax repo commit hash: 242f84c
Pip freeze output:

absl-py==1.0.0
astunparse==1.6.3
attrs==21.2.0
...

I was unable to create a TPU node with software version v2-alpha-tpuv5-lite, but when I ran python imagenet_fake_data_benchmark.py on a TPU node with TPU type v5lite-8 and TPU software version tpu-ubuntu2204-base and with the specific package versions you listed, the script completed with no error. Could you try using tpu-ubuntu2204-base as the TPU software version if you haven't already?

Traceback (most recent call last):
  File "/home/alijafari/flax/examples/imagenet/imagenet_fake_data_benchmark.py", line 50, in test_fake_data
    train.train_and_evaluate(config, workdir)
  File "/home/alijafari/flax/examples/imagenet/train.py", line 413, in train_and_evaluate
    eval_metrics = common_utils.get_metrics(eval_metrics)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 86, in get_metrics
    return stack_forest(metrics_np)
  File "/home/alijafari/.local/lib/python3.10/site-packages/flax/training/common_utils.py", line 69, in stack_forest
    return jax.tree_util.tree_map(stack_args, *forest)
TypeError: tree_map() missing 1 required positional argument: 'tree'

This traceback implies that eval_metrics is an empty list. Could you inspect it and confirm this?
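
To see why an empty list would produce exactly this TypeError, here is a minimal sketch. It uses a stand-in, tree_map_like, that only copies the positional shape of jax.tree_util.tree_map (a function, then one required tree, then optional further trees). It is not the JAX implementation, and stack_forest_like is a hypothetical simplification of the helper shown in the traceback.

```python
def tree_map_like(f, tree, *rest):
    """Stand-in with the same positional shape as jax.tree_util.tree_map.

    Only handles flat dicts, which is enough to show the argument-count
    behaviour. This is an illustration, not the JAX implementation.
    """
    return {k: f(tree[k], *(r[k] for r in rest)) for k in tree}


def stack_forest_like(forest):
    """Hypothetical simplification: collect each metric across eval steps."""
    stack_args = lambda *args: list(args)
    # With an empty forest, *forest expands to nothing, so the required
    # `tree` argument is never supplied.
    return tree_map_like(stack_args, *forest)


# One dict per eval step works as expected.
stacked = stack_forest_like([{"loss": 1.0}, {"loss": 0.5}])

# Zero eval steps means an empty list, which reproduces the failure mode.
try:
    stack_forest_like([])
    empty_error = None
except TypeError as err:
    empty_error = str(err)
```

If eval_metrics is indeed empty, the call reduces to tree_map(stack_args) with no tree at all, which matches the "missing 1 required positional argument: 'tree'" message above.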

andsteing commented 11 months ago

Glad to see we got the dependency problems sorted out.

The remaining problem seems to be a configuration issue:

https://github.com/google/flax/blob/d059ba8aadfe839a7bb0ce7b2c47afb5d91fdf0a/examples/imagenet/configs/fake_data_benchmark.py#L22-L36

Since we're using ACCELERATOR_TYPE=v5e-16, we would have config.batch_size = 4096 and thus config.steps_per_eval = 0.

This in turn explains why we have empty eval_metrics:

https://github.com/google/flax/blob/d059ba8aadfe839a7bb0ce7b2c47afb5d91fdf0a/examples/imagenet/train.py#L409-L413

So I would try setting config.steps_per_eval to a value larger than zero. But I'm a bit puzzled by this error, since the same settings have been used for over 3 years, so I don't understand how this test ever would have worked on more than 2 devices.
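
The arithmetic can be sketched as follows. The per-device batch of 256 follows from 4096 / 16 above. The eval example count of 512 is an assumption, chosen because it is consistent with the "more than 2 devices" observation; the linked config file remains the authority on the actual constants.

```python
PER_DEVICE_BATCH = 256  # 4096 / 16 devices, as stated in this thread
EVAL_EXAMPLES = 512     # assumption; check the linked config for the real value


def steps_per_eval(device_count):
    """Integer division silently rounds down to zero on larger slices."""
    batch_size = PER_DEVICE_BATCH * device_count
    return EVAL_EXAMPLES // batch_size


# Print device count, global batch size, and resulting eval steps.
for n in (1, 2, 4, 16):
    print(n, PER_DEVICE_BATCH * n, steps_per_eval(n))
```

Under these assumptions, 1 or 2 devices give at least one eval step, while anything larger gives zero, so the eval loop never runs and nothing is appended to eval_metrics.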

@gkroiz can you confirm the details of an earlier device configuration where this test has worked?

gkroiz commented 11 months ago

@andsteing, ~confirming that I was able to run this test with the same setup on v4-32 without any issues. (v4-32 has the same number of chips as v5e-16).~ I didn't run the test for long enough; it also fails on v4-32.

andsteing commented 11 months ago

Can you try config.steps_per_eval = 1 and see if that runs?

gkroiz commented 11 months ago

I updated my comment above: the test on v4-32 failed for the same reason. However, with config.steps_per_eval = 1, the test seems to be running. I'll test this on v5e-16.
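
For reference, the workaround amounts to never letting the eval step count fall to zero. A minimal sketch, where clamp_steps_per_eval is a hypothetical helper and not part of the Flax example:

```python
def clamp_steps_per_eval(eval_examples, batch_size):
    """Return at least one eval step, even when batch_size > eval_examples.

    Hypothetical helper for illustration; the thread's workaround simply
    sets config.steps_per_eval = 1 by hand.
    """
    return max(1, eval_examples // batch_size)


# The 512 here is an assumed eval example count; 4096 is the batch size
# reported for v5e-16 earlier in the thread.
steps_large_slice = clamp_steps_per_eval(512, 4096)
steps_single_device = clamp_steps_per_eval(512, 256)
```

With fake data this is harmless, since every batch is synthetic. On a real dataset, a batch larger than the eval split would need separate handling.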

gkroiz commented 11 months ago

On v5e, training and evaluation seem to work, but the code eventually freezes. Could this be related to checkpointing?

Here is the output:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_ID --zone=$ZONE --worker=all --command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py' 2>&1 | tee jax-test2.log
Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
2023-10-02 21:33:13.178018: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.213814: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.228705: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.231578: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.718224: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.718274: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.718280: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-02 21:33:13.744422: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.744471: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.744477: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-02 21:33:13.756624: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.756673: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.756679: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-02 21:33:13.766900: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.766948: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-02 21:33:13.766953: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-02 21:33:14.838154: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-02 21:33:14.838176: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-02 21:33:14.839311: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-02 21:33:14.839334: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-02 21:33:14.850979: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-02 21:33:14.851002: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-02 21:33:14.877592: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-02 21:33:14.877615: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696282395.101129   11281 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696282395.101158   11281 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696282395.101160   11281 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696282395.114954   11372 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696282395.114983   11372 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696282395.114985   11372 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696282395.118679   11265 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696282395.118709   11265 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696282395.118711   11265 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696282395.149025   11330 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/alijafari/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696282395.149054   11330 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696282395.149056   11330 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
I1002 21:33:25.472892 139933356431360 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:25.475714 139933356431360 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:33:25.475951 139933356431360 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:33:25.476093 139933356431360 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:25.476928 139933356431360 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:25.523503 139933356431360 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:33:25.581774 140416143869952 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:25.584514 140416143869952 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:33:25.584738 140416143869952 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:33:25.584877 140416143869952 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:25.585661 140416143869952 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:25.630933 140416143869952 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1002 21:33:25.742024 139933356431360 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1002 21:33:25.837592 140416143869952 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1002 21:33:26.036740 139933356431360 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.037425 139933356431360 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:26.058284 139933356431360 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:33:26.126379 140017219430400 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.129148 140017219430400 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:33:26.129365 140017219430400 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:33:26.129508 140017219430400 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.130132 140017219430400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.141045 140416143869952 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.141921 140416143869952 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:26.159762 140017219430400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:33:26.165181 140416143869952 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1002 21:33:26.370064 140017219430400 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1002 21:33:26.664216 140017219430400 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.664877 140017219430400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:26.685927 140017219430400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:33:26.919408 140411760846848 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.922156 140411760846848 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:33:26.922379 140411760846848 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:33:26.922517 140411760846848 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:26.923253 140411760846848 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:26.968081 140411760846848 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1002 21:33:27.179234 140411760846848 deprecation.py:350] From /home/alijafari/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1002 21:33:27.476334 140411760846848 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:33:27.477014 140411760846848 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:33:27.498383 140411760846848 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:33:50.561522 139933356431360 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp8hr_h3p0/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:33:50.644490 139933356431360 train.py:378] Initial compilation, this might take some minutes...
I1002 21:33:50.794145 140416143869952 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpodbgd_9g/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:33:50.877022 140416143869952 train.py:378] Initial compilation, this might take some minutes...
I1002 21:33:50.899796 140017219430400 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:33:50.980936 140017219430400 train.py:378] Initial compilation, this might take some minutes...
I1002 21:33:51.858507 140411760846848 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp1td79dmz/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:33:51.938897 140411760846848 train.py:378] Initial compilation, this might take some minutes...
I1002 21:34:26.227666 139933356431360 train.py:384] Initial compilation completed.
I1002 21:34:26.242962 140017219430400 train.py:384] Initial compilation completed.
I1002 21:34:26.565754 140416143869952 train.py:384] Initial compilation completed.
I1002 21:34:27.635730 140411760846848 train.py:384] Initial compilation completed.
I1002 21:34:33.754988 140017219430400 local.py:50] Created artifact [10] Profile of type ArtifactType.URL and value None.
I1002 21:35:18.538424 140400640460352 logging_writer.py:35] [100] steps_per_second=1.140757, train_accuracy=0.012011718936264515, train_learning_rate=0.05076923221349716, train_loss=6.274664402008057
I1002 21:35:18.538672 139917864867392 logging_writer.py:35] [100] steps_per_second=1.137736, train_accuracy=0.012011718936264515, train_learning_rate=0.05076923221349716, train_loss=6.274664402008057
I1002 21:35:18.538827 140396270327360 logging_writer.py:35] [100] steps_per_second=1.154740, train_accuracy=0.012011718936264515, train_learning_rate=0.05076923221349716, train_loss=6.274664402008057
I1002 21:35:18.538597 140001745585728 logging_writer.py:35] [100] steps_per_second=1.142109, train_accuracy=0.012011718936264515, train_learning_rate=0.05076923221349716, train_loss=6.274664402008057
I1002 21:36:14.354088 139917864867392 logging_writer.py:35] [200] steps_per_second=1.791642, train_accuracy=0.05131591856479645, train_learning_rate=0.15333333611488342, train_loss=5.458009243011475
I1002 21:36:14.354106 140396270327360 logging_writer.py:35] [200] steps_per_second=1.791647, train_accuracy=0.05131591856479645, train_learning_rate=0.15333333611488342, train_loss=5.458009243011475
I1002 21:36:14.354475 140400640460352 logging_writer.py:35] [200] steps_per_second=1.791621, train_accuracy=0.05131591856479645, train_learning_rate=0.15333333611488342, train_loss=5.458009243011475
I1002 21:36:14.353846 140001745585728 logging_writer.py:35] [200] steps_per_second=1.791721, train_accuracy=0.05131591856479645, train_learning_rate=0.15333333611488342, train_loss=5.458009243011475
I1002 21:37:10.283636 140400640460352 logging_writer.py:35] [300] steps_per_second=1.787978, train_accuracy=0.13828368484973907, train_learning_rate=0.2558974325656891, train_loss=4.704421043395996
I1002 21:37:10.283439 139917864867392 logging_writer.py:35] [300] steps_per_second=1.787973, train_accuracy=0.13828368484973907, train_learning_rate=0.2558974325656891, train_loss=4.704421043395996
I1002 21:37:10.284340 140396270327360 logging_writer.py:35] [300] steps_per_second=1.787944, train_accuracy=0.13828368484973907, train_learning_rate=0.2558974325656891, train_loss=4.704421043395996
I1002 21:37:10.283701 140001745585728 logging_writer.py:35] [300] steps_per_second=1.787955, train_accuracy=0.13828368484973907, train_learning_rate=0.2558974325656891, train_loss=4.704421043395996
I1002 21:37:27.694692 140411760846848 train.py:415] eval epoch: 0, loss: 4.8501, accuracy: 16.11
I1002 21:37:27.694755 140416143869952 train.py:415] eval epoch: 0, loss: 4.8501, accuracy: 16.11
I1002 21:37:27.695066 140400640460352 logging_writer.py:35] [312] eval_accuracy=0.1611328125, eval_loss=4.850093841552734
I1002 21:37:27.694686 139933356431360 train.py:415] eval epoch: 0, loss: 4.8501, accuracy: 16.11
I1002 21:37:27.694985 139917864867392 logging_writer.py:35] [312] eval_accuracy=0.1611328125, eval_loss=4.850093841552734
I1002 21:37:27.695611 140396270327360 logging_writer.py:35] [312] eval_accuracy=0.1611328125, eval_loss=4.850093841552734
I1002 21:37:27.775917 140017219430400 train.py:415] eval epoch: 0, loss: 4.8501, accuracy: 16.11
I1002 21:37:27.776633 140001745585728 logging_writer.py:35] [312] eval_accuracy=0.1611328125, eval_loss=4.850093841552734
I1002 21:38:13.425730 139917864867392 logging_writer.py:35] [400] steps_per_second=1.583725, train_accuracy=0.3724340796470642, train_learning_rate=0.35846155881881714, train_loss=3.219879150390625
I1002 21:38:13.425510 140400640460352 logging_writer.py:35] [400] steps_per_second=1.583736, train_accuracy=0.3724340796470642, train_learning_rate=0.35846155881881714, train_loss=3.219879150390625
I1002 21:38:13.425862 140396270327360 logging_writer.py:35] [400] steps_per_second=1.583744, train_accuracy=0.3724340796470642, train_learning_rate=0.35846155881881714, train_loss=3.219879150390625
I1002 21:38:13.425578 140001728800320 logging_writer.py:35] [400] steps_per_second=1.583737, train_accuracy=0.3724340796470642, train_learning_rate=0.35846155881881714, train_loss=3.219879150390625
I1002 21:39:09.572115 140400640460352 logging_writer.py:35] [500] steps_per_second=1.781075, train_accuracy=0.8223657011985779, train_learning_rate=0.4610256254673004, train_loss=0.8868171572685242
I1002 21:39:09.572078 139917864867392 logging_writer.py:35] [500] steps_per_second=1.781084, train_accuracy=0.8223657011985779, train_learning_rate=0.4610256254673004, train_loss=0.8868171572685242
I1002 21:39:09.572421 140396270327360 logging_writer.py:35] [500] steps_per_second=1.781082, train_accuracy=0.8223657011985779, train_learning_rate=0.4610256254673004, train_loss=0.8868171572685242
I1002 21:39:09.571724 140001728800320 logging_writer.py:35] [500] steps_per_second=1.781117, train_accuracy=0.8223657011985779, train_learning_rate=0.4610256254673004, train_loss=0.8868171572685242
I1002 21:40:05.913566 139917864867392 logging_writer.py:35] [600] steps_per_second=1.774895, train_accuracy=0.9743383526802063, train_learning_rate=0.5635898113250732, train_loss=0.13830925524234772
I1002 21:40:05.913715 140400640460352 logging_writer.py:35] [600] steps_per_second=1.774888, train_accuracy=0.9743383526802063, train_learning_rate=0.5635898113250732, train_loss=0.13830925524234772
I1002 21:40:05.914761 140396270327360 logging_writer.py:35] [600] steps_per_second=1.774866, train_accuracy=0.9743383526802063, train_learning_rate=0.5635898113250732, train_loss=0.13830925524234772
I1002 21:40:05.911848 140001728800320 logging_writer.py:35] [600] steps_per_second=1.774931, train_accuracy=0.9743383526802063, train_learning_rate=0.5635898113250732, train_loss=0.13830925524234772
I1002 21:40:19.502022 139933356431360 train.py:415] eval epoch: 1, loss: 0.1311, accuracy: 97.66
I1002 21:40:19.501996 140416143869952 train.py:415] eval epoch: 1, loss: 0.1311, accuracy: 97.66
I1002 21:40:19.502644 140400640460352 logging_writer.py:35] [624] eval_accuracy=0.9765625, eval_loss=0.13113322854042053
I1002 21:40:19.501962 140411760846848 train.py:415] eval epoch: 1, loss: 0.1311, accuracy: 97.66
I1002 21:40:19.502587 140396270327360 logging_writer.py:35] [624] eval_accuracy=0.9765625, eval_loss=0.13113322854042053
I1002 21:40:19.502808 139917864867392 logging_writer.py:35] [624] eval_accuracy=0.9765625, eval_loss=0.13113322854042053
I1002 21:40:19.502541 140017219430400 train.py:415] eval epoch: 1, loss: 0.1311, accuracy: 97.66
I1002 21:40:19.503815 140001728800320 logging_writer.py:35] [624] eval_accuracy=0.9765625, eval_loss=0.13113322854042053
I1002 21:41:02.939436 140400640460352 logging_writer.py:35] [700] steps_per_second=1.753597, train_accuracy=0.9886084198951721, train_learning_rate=0.6661539673805237, train_loss=0.06366942077875137
I1002 21:41:02.940451 139917864867392 logging_writer.py:35] [700] steps_per_second=1.753559, train_accuracy=0.9886084198951721, train_learning_rate=0.6661539673805237, train_loss=0.06366942077875137
I1002 21:41:02.941023 140396270327360 logging_writer.py:35] [700] steps_per_second=1.753578, train_accuracy=0.9886084198951721, train_learning_rate=0.6661539673805237, train_loss=0.06366942077875137
I1002 21:41:02.940557 140001745585728 logging_writer.py:35] [700] steps_per_second=1.753507, train_accuracy=0.9886084198951721, train_learning_rate=0.6661539673805237, train_loss=0.06366942077875137
I1002 21:41:58.619088 139917864867392 logging_writer.py:35] [800] steps_per_second=1.796044, train_accuracy=0.9912988543510437, train_learning_rate=0.7687180042266846, train_loss=0.04839024692773819
I1002 21:41:58.619411 140400640460352 logging_writer.py:35] [800] steps_per_second=1.795992, train_accuracy=0.9912988543510437, train_learning_rate=0.7687180042266846, train_loss=0.04839024692773819
I1002 21:41:58.619482 140396270327360 logging_writer.py:35] [800] steps_per_second=1.796048, train_accuracy=0.9912988543510437, train_learning_rate=0.7687180042266846, train_loss=0.04839024692773819
I1002 21:41:58.618924 140001745585728 logging_writer.py:35] [800] steps_per_second=1.796077, train_accuracy=0.9912988543510437, train_learning_rate=0.7687180042266846, train_loss=0.04839024692773819
I1002 21:42:54.633974 140400640460352 logging_writer.py:35] [900] steps_per_second=1.785251, train_accuracy=0.7319629192352295, train_learning_rate=0.8712821006774902, train_loss=2.548089027404785
I1002 21:42:54.634981 139917864867392 logging_writer.py:35] [900] steps_per_second=1.785210, train_accuracy=0.7319629192352295, train_learning_rate=0.8712821006774902, train_loss=2.548089027404785
I1002 21:42:54.635331 140396270327360 logging_writer.py:35] [900] steps_per_second=1.785213, train_accuracy=0.7319629192352295, train_learning_rate=0.8712821006774902, train_loss=2.548089027404785
I1002 21:42:54.634346 140001745585728 logging_writer.py:35] [900] steps_per_second=1.785226, train_accuracy=0.7319629192352295, train_learning_rate=0.8712821006774902, train_loss=2.548089027404785
I1002 21:43:14.804400 139933356431360 train.py:415] eval epoch: 2, loss: 6.6112, accuracy: 0.59
I1002 21:43:14.804507 140416143869952 train.py:415] eval epoch: 2, loss: 6.6112, accuracy: 0.59
I1002 21:43:14.805023 139917864867392 logging_writer.py:35] [936] eval_accuracy=0.005859375, eval_loss=6.6112446784973145
I1002 21:43:14.804613 140411760846848 train.py:415] eval epoch: 2, loss: 6.6112, accuracy: 0.59
I1002 21:43:14.805183 140396270327360 logging_writer.py:35] [936] eval_accuracy=0.005859375, eval_loss=6.6112446784973145
I1002 21:43:14.805149 140400640460352 logging_writer.py:35] [936] eval_accuracy=0.005859375, eval_loss=6.6112446784973145
I1002 21:43:14.804787 140017219430400 train.py:415] eval epoch: 2, loss: 6.6112, accuracy: 0.59
I1002 21:43:14.806575 140001745585728 logging_writer.py:35] [936] eval_accuracy=0.005859375, eval_loss=6.6112446784973145
I1002 21:43:51.424899 140400640460352 logging_writer.py:35] [1000] steps_per_second=1.760847, train_accuracy=0.006696777418255806, train_learning_rate=0.9738461971282959, train_loss=6.525280952453613
I1002 21:43:51.426173 139917864867392 logging_writer.py:35] [1000] steps_per_second=1.760837, train_accuracy=0.006696777418255806, train_learning_rate=0.9738461971282959, train_loss=6.525280952453613
I1002 21:43:51.426186 140396270327360 logging_writer.py:35] [1000] steps_per_second=1.760848, train_accuracy=0.006696777418255806, train_learning_rate=0.9738461971282959, train_loss=6.525280952453613
I1002 21:43:51.426347 140001728800320 logging_writer.py:35] [1000] steps_per_second=1.760814, train_accuracy=0.006696777418255806, train_learning_rate=0.9738461971282959, train_loss=6.525280952453613
I1002 21:44:47.876418 140400640460352 logging_writer.py:35] [1100] steps_per_second=1.771447, train_accuracy=0.008820801042020321, train_learning_rate=1.0764102935791016, train_loss=6.023613929748535
I1002 21:44:47.877130 139917864867392 logging_writer.py:35] [1100] steps_per_second=1.771471, train_accuracy=0.008820801042020321, train_learning_rate=1.0764102935791016, train_loss=6.023613929748535
I1002 21:44:47.877115 140396270327360 logging_writer.py:35] [1100] steps_per_second=1.771474, train_accuracy=0.008820801042020321, train_learning_rate=1.0764102935791016, train_loss=6.023613929748535
I1002 21:44:47.877155 140001728800320 logging_writer.py:35] [1100] steps_per_second=1.771506, train_accuracy=0.008820801042020321, train_learning_rate=1.0764102935791016, train_loss=6.023613929748535
I1002 21:45:43.947209 140400640460352 logging_writer.py:35] [1200] steps_per_second=1.783463, train_accuracy=0.015727538615465164, train_learning_rate=1.1789745092391968, train_loss=5.665318489074707
I1002 21:45:43.948185 139917864867392 logging_writer.py:35] [1200] steps_per_second=1.783451, train_accuracy=0.015727538615465164, train_learning_rate=1.1789745092391968, train_loss=5.665318489074707
I1002 21:45:43.948034 140396270327360 logging_writer.py:35] [1200] steps_per_second=1.783457, train_accuracy=0.015727538615465164, train_learning_rate=1.1789745092391968, train_loss=5.665318489074707
I1002 21:45:43.946491 140001728800320 logging_writer.py:35] [1200] steps_per_second=1.783506, train_accuracy=0.015727538615465164, train_learning_rate=1.1789745092391968, train_loss=5.665318489074707
I1002 21:46:11.037369 139933356431360 train.py:415] eval epoch: 3, loss: 5.7759, accuracy: 2.73
I1002 21:46:11.037415 140416143869952 train.py:415] eval epoch: 3, loss: 5.7759, accuracy: 2.73
I1002 21:46:11.037302 140411760846848 train.py:415] eval epoch: 3, loss: 5.7759, accuracy: 2.73
I1002 21:46:11.038008 139917864867392 logging_writer.py:35] [1248] eval_accuracy=0.02734375, eval_loss=5.775895595550537
I1002 21:46:11.037983 140396270327360 logging_writer.py:35] [1248] eval_accuracy=0.02734375, eval_loss=5.775895595550537
I1002 21:46:11.038063 140400640460352 logging_writer.py:35] [1248] eval_accuracy=0.02734375, eval_loss=5.775895595550537
I1002 21:46:11.037597 140017219430400 train.py:415] eval epoch: 3, loss: 5.7759, accuracy: 2.73
I1002 21:46:11.038964 140001728800320 logging_writer.py:35] [1248] eval_accuracy=0.02734375, eval_loss=5.775895595550537
I1002 21:46:40.891882 140400640460352 logging_writer.py:35] [1300] steps_per_second=1.756090, train_accuracy=0.053569335490465164, train_learning_rate=1.281538486480713, train_loss=5.112351894378662
I1002 21:46:40.892645 140396270327360 logging_writer.py:35] [1300] steps_per_second=1.756094, train_accuracy=0.053569335490465164, train_learning_rate=1.281538486480713, train_loss=5.112351894378662
I1002 21:46:40.892704 139917864867392 logging_writer.py:35] [1300] steps_per_second=1.756098, train_accuracy=0.053569335490465164, train_learning_rate=1.281538486480713, train_loss=5.112351894378662
I1002 21:46:40.892796 140001745585728 logging_writer.py:35] [1300] steps_per_second=1.756043, train_accuracy=0.053569335490465164, train_learning_rate=1.281538486480713, train_loss=5.112351894378662
I1002 21:47:37.381415 140400640460352 logging_writer.py:35] [1400] steps_per_second=1.770255, train_accuracy=0.17946776747703552, train_learning_rate=1.3841028213500977, train_loss=4.246438026428223
I1002 21:47:37.382330 139917864867392 logging_writer.py:35] [1400] steps_per_second=1.770257, train_accuracy=0.17946776747703552, train_learning_rate=1.3841028213500977, train_loss=4.246438026428223
I1002 21:47:37.382674 140396270327360 logging_writer.py:35] [1400] steps_per_second=1.770243, train_accuracy=0.17946776747703552, train_learning_rate=1.3841028213500977, train_loss=4.246438026428223
I1002 21:47:37.382471 140001745585728 logging_writer.py:35] [1400] steps_per_second=1.770281, train_accuracy=0.17946776747703552, train_learning_rate=1.3841028213500977, train_loss=4.246438026428223
I1002 21:48:33.766398 140400640460352 logging_writer.py:35] [1500] steps_per_second=1.773523, train_accuracy=0.4083007872104645, train_learning_rate=1.4866665601730347, train_loss=2.9806673526763916
I1002 21:48:33.767817 139917864867392 logging_writer.py:35] [1500] steps_per_second=1.773510, train_accuracy=0.4083007872104645, train_learning_rate=1.4866665601730347, train_loss=2.9806673526763916
I1002 21:48:33.767883 140396270327360 logging_writer.py:35] [1500] steps_per_second=1.773518, train_accuracy=0.4083007872104645, train_learning_rate=1.4866665601730347, train_loss=2.9806673526763916
I1002 21:48:33.766485 140001745585728 logging_writer.py:35] [1500] steps_per_second=1.773554, train_accuracy=0.4083007872104645, train_learning_rate=1.4866665601730347, train_loss=2.9806673526763916
I1002 21:49:07.393768 140411760846848 train.py:415] eval epoch: 4, loss: 1.8427, accuracy: 60.84
I1002 21:49:07.393653 139933356431360 train.py:415] eval epoch: 4, loss: 1.8427, accuracy: 60.84
I1002 21:49:07.393865 140416143869952 train.py:415] eval epoch: 4, loss: 1.8427, accuracy: 60.84
I1002 21:49:07.394260 139917864867392 logging_writer.py:35] [1560] eval_accuracy=0.6083984375, eval_loss=1.8427281379699707
I1002 21:49:07.394444 140396270327360 logging_writer.py:35] [1560] eval_accuracy=0.6083984375, eval_loss=1.8427281379699707
I1002 21:49:07.394522 140400640460352 logging_writer.py:35] [1560] eval_accuracy=0.6083984375, eval_loss=1.8427281379699707
I1002 21:49:07.393724 140017219430400 train.py:415] eval epoch: 4, loss: 1.8427, accuracy: 60.84
I1002 21:49:07.395623 140001745585728 logging_writer.py:35] [1560] eval_accuracy=0.6083984375, eval_loss=1.8427281379699707
I1002 21:49:07.661298 139933356431360 train.py:231] Saving checkpoint step 1560.
I1002 21:49:07.667837 140411760846848 train.py:231] Saving checkpoint step 1560.
I1002 21:49:07.669569 140416143869952 train.py:231] Saving checkpoint step 1560.
I1002 21:49:07.872932 140017219430400 train.py:231] Saving checkpoint step 1560.
I1002 21:49:07.917415 140416143869952 checkpoints.py:571] Saving checkpoint at step: 1560
I1002 21:49:07.917531 140416143869952 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1002 21:49:07.917694 140416143869952 type_handlers.py:223] OCDBT is initialized successfully.
I1002 21:49:07.917503 140411760846848 checkpoints.py:571] Saving checkpoint at step: 1560
I1002 21:49:07.917655 140411760846848 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1002 21:49:07.917536 139933356431360 checkpoints.py:571] Saving checkpoint at step: 1560
I1002 21:49:07.917850 139933356431360 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1002 21:49:07.917829 140411760846848 type_handlers.py:223] OCDBT is initialized successfully.
I1002 21:49:07.918013 139933356431360 type_handlers.py:223] OCDBT is initialized successfully.
I1002 21:49:07.918971 140416143869952 checkpointer.py:67] Saving item to /tmp/tmpodbgd_9g/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.
I1002 21:49:07.919377 139933356431360 checkpointer.py:67] Saving item to /tmp/tmp8hr_h3p0/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.
I1002 21:49:07.920249 140411760846848 checkpointer.py:67] Saving item to /tmp/tmp1td79dmz/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.
I1002 21:49:07.918579 140017219430400 checkpoints.py:571] Saving checkpoint at step: 1560
I1002 21:49:07.919683 140017219430400 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1002 21:49:07.919870 140017219430400 type_handlers.py:223] OCDBT is initialized successfully.
I1002 21:49:07.922735 140017219430400 checkpointer.py:67] Saving item to /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.
I1002 21:49:08.582233 140017219430400 utils.py:522] Renaming /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.orbax-checkpoint-tmp-1696283347923311 to /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560
I1002 21:49:08.583415 140017219430400 utils.py:566] Finished saving checkpoint to `/tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560`.
I1002 21:49:08.882952 140416143869952 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.885487 140416143869952 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.885710 140416143869952 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.885855 140416143869952 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.886517 140416143869952 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.897555 140411760846848 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.897615 139933356431360 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.900106 139933356431360 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.900037 140411760846848 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.900235 140411760846848 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.900371 140411760846848 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.900305 139933356431360 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.900436 139933356431360 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.901031 140411760846848 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.901119 139933356431360 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.901741 140017219430400 dataset_info.py:358] Load dataset info from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.904234 140017219430400 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.904443 140017219430400 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1002 21:49:08.904574 140017219430400 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:08.905229 140017219430400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:49:08.909770 140416143869952 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:08.924156 139933356431360 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:08.924472 140411760846848 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:08.928236 140017219430400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:49:09.008003 140416143869952 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.008627 140416143869952 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.020364 139933356431360 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.020982 139933356431360 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.022900 140411760846848 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.023515 140411760846848 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.025855 140017219430400 mocking.py:151] Metadata found for imagenet2012 at /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
I1002 21:49:09.026493 140017219430400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/alijafari/flax/.tfds/metadata/imagenet2012/5.1.0
W1002 21:49:09.030889 140416143869952 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:09.042039 139933356431360 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:09.045854 140411760846848 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1002 21:49:09.048129 140017219430400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1002 21:49:33.352570 139933356431360 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp8hr_h3p0/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:49:33.434531 139933356431360 train.py:378] Initial compilation, this might take some minutes...
I1002 21:49:33.677163 140017219430400 checkpoints.py:1064] Restoring orbax checkpoint from /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560
I1002 21:49:33.684611 140411760846848 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp1td79dmz/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1002 21:49:33.686718 140017219430400 checkpointer.py:97] Restoring item from /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560.
I1002 21:49:33.767652 140411760846848 train.py:378] Initial compilation, this might take some minutes...
I1002 21:49:33.980492 140416143869952 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpodbgd_9g/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
W1002 21:49:34.007778 140017219430400 transform_utils.py:229] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I1002 21:49:34.062245 140416143869952 train.py:378] Initial compilation, this might take some minutes...
I1002 21:50:08.994675 140411760846848 train.py:384] Initial compilation completed.
I1002 21:50:09.158574 139933356431360 train.py:384] Initial compilation completed.
I1002 21:50:09.278753 140416143869952 train.py:384] Initial compilation completed.

gkroiz commented 11 months ago

Noticed this same behavior on v4-32

chiamp commented 11 months ago

I wonder if this if condition is not evaluating to True. What value are you getting for steps_per_epoch?

gkroiz commented 11 months ago

I wonder if this if condition is not evaluating to True. What value are you getting for steps_per_epoch?

@chiamp Sorry for the late response here. When running on v4-32, I'm getting steps_per_epoch=312.
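
For reference, the eval entries in the log earlier in this thread ([312], [624], [936], [1248], [1560]) are consistent with that value. Below is a minimal sketch of the arithmetic, assuming the condition in question has the form `(step + 1) % steps_per_epoch == 0` with a 0-indexed `step`. That form is an assumption, since the linked line is not reproduced here, and `is_epoch_boundary` is a hypothetical helper.

```python
# Value reported above for v4-32.
steps_per_epoch = 312


def is_epoch_boundary(step: int, steps_per_epoch: int) -> bool:
    """Hypothetical form of the check: true on the last step of each epoch.

    `step` is assumed to be 0-indexed, so the first boundary is step 311.
    """
    return (step + 1) % steps_per_epoch == 0


# Step numbers attached to the eval_accuracy/eval_loss entries in the log.
eval_steps_logged = [312, 624, 936, 1248, 1560]

# Steps (1-indexed, as logged) at which the assumed condition fires.
boundaries = [s + 1 for s in range(1560) if is_epoch_boundary(s, steps_per_epoch)]

print(boundaries == eval_steps_logged)
```

If the condition had a different form, the logged eval steps would still need to land on multiples of 312 for it to match this run.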

chiamp commented 11 months ago

On v5e we noticed that training and evaluation seem to work but eventually the code freezes. Here is the output:

Could this be related to checkpointing?

@IvyZX, do you think it's a checkpointing issue?
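
One way to look into this from the logs is to group checkpoint-related messages by the thread ID in the absl log prefix and compare what each worker did. Below is a minimal sketch, assuming the prefix format `I<MMDD> <time> <thread id> <file>:<line>] <message>` seen in the output. `checkpoint_events` is a hypothetical helper, and the sample lines are copied from the log earlier in this thread.

```python
import re
from collections import defaultdict

# absl log prefix: severity + date, time, thread id, source file:line, message.
LOG_RE = re.compile(r"^[IWEF]\d{4} [\d:.]+ (\d+) ([\w.]+):\d+\] (.*)$")


def checkpoint_events(lines):
    """Hypothetical helper: map thread id -> checkpoint-related messages."""
    events = defaultdict(list)
    for line in lines:
        m = LOG_RE.match(line)
        if m and "checkpoint" in m.group(3).lower():
            events[m.group(1)].append(m.group(3))
    return dict(events)


sample = [
    "I1002 21:49:08.583415 140017219430400 utils.py:566] Finished saving checkpoint to `/tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560`.",
    "I1002 21:49:33.352570 139933356431360 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp8hr_h3p0/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_",
    "I1002 21:49:33.677163 140017219430400 checkpoints.py:1064] Restoring orbax checkpoint from /tmp/tmp__vlwx_o/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1560",
]

events = checkpoint_events(sample)
for thread_id, msgs in events.items():
    print(thread_id, len(msgs))
```

In the pasted output, only thread 140017219430400 logs both "Finished saving checkpoint" and "Restoring orbax checkpoint"; the other three threads report "Found no checkpoint files" in their own temporary directories. Whether that difference is related to the freeze is not established here.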

gkroiz commented 11 months ago

@chiamp I'm somewhat confused here. Shouldn't this issue still be open?

chiamp commented 11 months ago

Apologies, it seems that PR #3386 (which sets config.num_train_steps and config.steps_per_eval to 1) was linked to this issue and automatically closed it once it was submitted.

RissyRan commented 11 months ago

The training still hangs there with the latest change.

Full logs:

2023-10-11 01:27:34.284670: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 01:27:34.625669: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 01:27:34.809357: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:34.809408: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:34.809414: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 01:27:34.876478: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.131153: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.279691: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.279743: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.279750: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 01:27:35.403274: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.403322: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.403328: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 01:27:35.797554: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.797604: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.797610: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 01:27:35.898936: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 01:27:35.898961: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696987656.278455   10017 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696987656.278485   10017 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696987656.278487   10017 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 01:27:36.515545: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 01:27:36.515577: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 01:27:36.710666: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 01:27:36.710691: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696987656.824012   16022 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696987656.824040   16022 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696987656.824042   16022 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696987657.032997   14975 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696987657.033023   14975 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696987657.033025   14975 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 01:27:37.269418: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 01:27:37.269442: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1696987657.594851   13801 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1696987657.594882   13801 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1696987657.594884   13801 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
I1011 01:27:46.438061 139990409582592 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.440865 139990409582592 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.441073 139990409582592 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.441209 139990409582592 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.441837 139990409582592 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:46.486543 139990409582592 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:27:46.588143 140074188843008 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.590813 140074188843008 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.591015 140074188843008 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.591147 140074188843008 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.591765 140074188843008 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:46.636430 140074188843008 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 01:27:46.695426 139990409582592 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 01:27:46.747684 139834310957056 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.750409 139834310957056 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.750622 139834310957056 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:27:46.750761 139834310957056 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.751528 139834310957056 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:46.780556 139834310957056 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 01:27:46.846010 140074188843008 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 01:27:46.990176 139990409582592 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:46.990841 139990409582592 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 01:27:46.992446 139834310957056 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 01:27:47.011812 139990409582592 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:27:47.142456 140074188843008 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:47.143145 140074188843008 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:47.164407 140074188843008 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:27:47.288790 139834310957056 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:47.289528 139834310957056 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:47.311395 139834310957056 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:27:48.178797 140298040678400 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:48.181573 140298040678400 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:27:48.181784 140298040678400 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:27:48.181932 140298040678400 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:48.182553 140298040678400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:48.228492 140298040678400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 01:27:48.441981 140298040678400 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 01:27:48.748358 140298040678400 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:27:48.749107 140298040678400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:27:48.770431 140298040678400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:28:11.197376 139990409582592 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpvzx1_7_q/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:28:11.278421 139990409582592 train.py:378] Initial compilation, this might take some minutes...
I1011 01:28:11.434583 139834310957056 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:28:11.515003 139834310957056 train.py:378] Initial compilation, this might take some minutes...
I1011 01:28:11.683796 140074188843008 checkpoints.py:1054] Found no checkpoint files in /tmp/tmphaxp2xpk/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:28:11.765251 140074188843008 train.py:378] Initial compilation, this might take some minutes...
I1011 01:28:12.831703 140298040678400 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp13hvsq07/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:28:12.914557 140298040678400 train.py:378] Initial compilation, this might take some minutes...
I1011 01:28:46.664929 139990409582592 train.py:384] Initial compilation completed.
I1011 01:28:46.997515 140074188843008 train.py:384] Initial compilation completed.
I1011 01:28:47.162226 139834310957056 train.py:384] Initial compilation completed.
I1011 01:28:48.478332 140298040678400 train.py:384] Initial compilation completed.
I1011 01:28:49.194751 139990409582592 train.py:231] Saving checkpoint step 1.
I1011 01:28:49.230294 140074188843008 train.py:231] Saving checkpoint step 1.
I1011 01:28:49.245000 139834310957056 train.py:231] Saving checkpoint step 1.
I1011 01:28:49.571197 140298040678400 train.py:231] Saving checkpoint step 1.
I1011 01:28:49.610326 140298040678400 checkpoints.py:571] Saving checkpoint at step: 1
I1011 01:28:49.610457 140298040678400 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1011 01:28:49.610620 140298040678400 type_handlers.py:223] OCDBT is initialized successfully.
I1011 01:28:49.610318 139834310957056 checkpoints.py:571] Saving checkpoint at step: 1
I1011 01:28:49.610454 139834310957056 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1011 01:28:49.610617 139834310957056 type_handlers.py:223] OCDBT is initialized successfully.
I1011 01:28:49.610426 140074188843008 checkpoints.py:571] Saving checkpoint at step: 1
I1011 01:28:49.610563 140074188843008 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1011 01:28:49.610722 140074188843008 type_handlers.py:223] OCDBT is initialized successfully.
I1011 01:28:49.611991 140298040678400 checkpointer.py:67] Saving item to /tmp/tmp13hvsq07/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.
I1011 01:28:49.610365 139990409582592 checkpoints.py:571] Saving checkpoint at step: 1
I1011 01:28:49.610523 139990409582592 checkpoints.py:792] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
I1011 01:28:49.610713 139990409582592 type_handlers.py:223] OCDBT is initialized successfully.
I1011 01:28:49.612014 139834310957056 checkpointer.py:67] Saving item to /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.
I1011 01:28:49.612037 140074188843008 checkpointer.py:67] Saving item to /tmp/tmphaxp2xpk/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.
I1011 01:28:49.612159 139990409582592 checkpointer.py:67] Saving item to /tmp/tmpvzx1_7_q/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.
I1011 01:28:50.681691 139834310957056 utils.py:522] Renaming /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.orbax-checkpoint-tmp-1696987729612269 to /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1
I1011 01:28:50.681884 139834310957056 utils.py:566] Finished saving checkpoint to `/tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1`.
I1011 01:28:50.982564 139834310957056 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:50.985078 139834310957056 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:28:50.985274 139834310957056 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:28:50.985406 139834310957056 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:50.986055 139834310957056 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:28:51.008818 139834310957056 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:28:51.107008 139834310957056 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.107613 139834310957056 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.107779 140074188843008 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.110271 140074188843008 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.110470 140074188843008 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.110601 140074188843008 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.111253 140074188843008 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.119885 140298040678400 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.122518 140298040678400 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.122715 140298040678400 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.122844 140298040678400 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.123497 140298040678400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.124710 139990409582592 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.127226 139990409582592 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.127423 139990409582592 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 01:28:51.127554 139990409582592 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:28:51.128889 139834310957056 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:28:51.128194 139990409582592 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:28:51.134078 140074188843008 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 01:28:51.146520 140298040678400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 01:28:51.151225 139990409582592 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:28:51.231164 140074188843008 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.231767 140074188843008 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.244602 140298040678400 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.245238 140298040678400 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.247156 139990409582592 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 01:28:51.247810 139990409582592 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 01:28:51.253009 140074188843008 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 01:28:51.266740 140298040678400 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 01:28:51.268987 139990409582592 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 01:29:15.316825 140074188843008 checkpoints.py:1054] Found no checkpoint files in /tmp/tmphaxp2xpk/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:29:15.350800 139990409582592 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpvzx1_7_q/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:29:15.352745 140298040678400 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp13hvsq07/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 01:29:15.397813 140074188843008 train.py:378] Initial compilation, this might take some minutes...
I1011 01:29:15.430036 139990409582592 train.py:378] Initial compilation, this might take some minutes...
I1011 01:29:15.435266 140298040678400 train.py:378] Initial compilation, this might take some minutes...
I1011 01:29:15.479074 139834310957056 checkpoints.py:1064] Restoring orbax checkpoint from /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1
I1011 01:29:15.488394 139834310957056 checkpointer.py:97] Restoring item from /tmp/tmpqzvq0xvp/ImagenetBenchmarkFakeData.test_fake_data/checkpoint_1.
W1011 01:29:15.680681 139834310957056 transform_utils.py:229] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I1011 01:29:50.906538 139990409582592 train.py:384] Initial compilation completed.
I1011 01:29:52.072032 140298040678400 train.py:384] Initial compilation completed.
I1011 01:29:52.267601 140074188843008 train.py:384] Initial compilation completed.

Can anyone verify it? Thank you!

RissyRan commented 11 months ago

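I think this is related to checkpoint saving/restoring. The earlier logs show `Saving checkpoint at step: 1` on every worker, but during restoring I noticed both `Found no checkpoint files in ... with prefix checkpoint_` (on three workers) and `Restoring orbax checkpoint from ...` (on one worker).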

It seems the checkpoint is actually stored on only one worker, so the other workers cannot restore it. This causes the hang.
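
To illustrate the suspected mismatch, here is a minimal pure-Python sketch. It is not Flax or Orbax code: `latest_checkpoint` and `simulate` are hypothetical helpers, and the leftover temporary directory on the other workers is an assumption based on the fact that the logs show only one worker printing `Finished saving checkpoint`. Each worker gets its own local work directory, as with the separate `/tmp/tmp...` directories in the logs.

```python
import os
import tempfile

PREFIX = "checkpoint_"


def latest_checkpoint(workdir, prefix=PREFIX):
    """Return the path of the highest-step finalized checkpoint, or None.

    Hypothetical helper: only names of the form `<prefix><digits>` count,
    so unfinished temporary directories are ignored.
    """
    steps = []
    for name in os.listdir(workdir):
        suffix = name[len(prefix):]
        if name.startswith(prefix) and suffix.isdigit():
            steps.append(int(suffix))
    if not steps:
        return None
    return os.path.join(workdir, f"{prefix}{max(steps)}")


def simulate(num_workers=4, finalizing_worker=0):
    """Give every worker a private workdir; only one finalizes checkpoint_1."""
    decisions = []
    with tempfile.TemporaryDirectory() as root:
        for worker in range(num_workers):
            workdir = os.path.join(root, f"worker_{worker}")
            os.makedirs(workdir)
            if worker == finalizing_worker:
                # This worker ends up with a finalized checkpoint directory.
                os.makedirs(os.path.join(workdir, "checkpoint_1"))
            else:
                # Assumption: the other workers never get a finalized
                # directory, only an unfinished temporary one.
                os.makedirs(
                    os.path.join(workdir, "checkpoint_1.orbax-checkpoint-tmp-0")
                )
            found = latest_checkpoint(workdir)
            decisions.append("restore" if found else "fresh start")
    return decisions


if __name__ == "__main__":
    print(simulate())
```

In this sketch one worker decides to restore while the other three start fresh, so the workers no longer follow the same code path. If the real run diverges the same way, any step that needs all workers to participate could wait forever, which would match the hang.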

So I removed lines 425-427 in train.py for testing.

And the training works fine.

Using ssh batch size of 4. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
TFDS metadata already exists.
TFDS metadata already exists.
2023-10-11 06:41:47.831533: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 06:41:47.914994: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
TFDS metadata already exists.
TFDS metadata already exists.
2023-10-11 06:41:48.363867: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:48.363916: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:48.363922: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 06:41:48.450789: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:48.450840: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:48.450846: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 06:41:48.798182: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 06:41:48.813800: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.322817: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.322866: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.322872: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-10-11 06:41:49.347628: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.347677: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.347683: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 06:41:49.472730: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.472753: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 06:41:49.594028: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 06:41:49.594053: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1697006509.739886   43000 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1697006509.739916   43000 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1697006509.739918   43000 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1697006509.869077   41796 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1697006509.869106   41796 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1697006509.869108   41796 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 06:41:50.429925: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 06:41:50.429949: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
Running tests under Python 3.10.12: /usr/bin/python3
2023-10-11 06:41:50.455175: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-10-11 06:41:50.455199: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
[ RUN      ] ImagenetBenchmarkFakeData.test_fake_data
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1697006510.702444   43139 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1697006510.702474   43139 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1697006510.702476   43139 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1697006510.718483   43099 pjrt_api.cc:98] GetPjrtApi was found for tpu at /home/ranran/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1697006510.718514   43099 pjrt_api.cc:67] PJRT_Api is set for device type tpu
I0000 00:00:1697006510.718517   43099 pjrt_api.cc:72] PJRT plugin for tpu has PJRT API version 0.30. The framework PJRT API version is 0.30.
I1011 06:42:01.258782 140704714082304 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:01.261477 140704714082304 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:42:01.261690 140704714082304 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:42:01.261832 140704714082304 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:01.262503 140704714082304 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:01.308968 140704714082304 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 06:42:01.422235 140704714082304 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 06:42:01.718919 140704714082304 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:01.719579 140704714082304 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:01.743072 140704714082304 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:42:02.301202 140231361497088 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:02.303973 140231361497088 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:42:02.304185 140231361497088 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:42:02.304322 140231361497088 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:02.304980 140231361497088 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:02.350966 140231361497088 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 06:42:02.464453 140231361497088 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 06:42:02.759400 140231361497088 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:02.760054 140231361497088 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:02.781184 140231361497088 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:42:02.934723 140399706380288 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:02.937496 140399706380288 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:42:02.937703 140399706380288 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:42:02.937839 140399706380288 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:02.938472 140399706380288 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:02.983524 140399706380288 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 06:42:03.097830 140399706380288 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 06:42:03.393968 140399706380288 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:03.394631 140399706380288 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:03.415513 140399706380288 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:42:04.566965 139772030900224 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:04.569774 139772030900224 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:42:04.569999 139772030900224 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:42:04.570140 139772030900224 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:04.570912 139772030900224 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:04.600678 139772030900224 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
WARNING:tensorflow:From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
W1011 06:42:04.825872 139772030900224 deprecation.py:350] From /home/ranran/.local/lib/python3.10/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
I1011 06:42:05.130667 139772030900224 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:42:05.131510 139772030900224 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:42:05.153764 139772030900224 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:42:26.101440 140704714082304 checkpoints.py:1054] Found no checkpoint files in /tmp/tmprexpr086/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:42:26.182769 140704714082304 train.py:378] Initial compilation, this might take some minutes...
I1011 06:42:26.981086 140231361497088 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpqobamzrr/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:42:27.062227 140231361497088 train.py:378] Initial compilation, this might take some minutes...
I1011 06:42:27.489730 140399706380288 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp62mkvr_m/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:42:27.570203 140399706380288 train.py:378] Initial compilation, this might take some minutes...
I1011 06:42:29.672261 139772030900224 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpfjw3fkuj/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:42:29.752334 139772030900224 train.py:378] Initial compilation, this might take some minutes...
I1011 06:43:01.645021 140704714082304 train.py:384] Initial compilation completed.
I1011 06:43:02.443091 140231361497088 train.py:384] Initial compilation completed.
I1011 06:43:02.677765 140399706380288 train.py:384] Initial compilation completed.
I1011 06:43:05.470218 139772030900224 train.py:384] Initial compilation completed.
I1011 06:43:05.630665 140399706380288 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.633227 140399706380288 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.633424 140399706380288 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.633552 140399706380288 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.633757 140704714082304 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.634183 140399706380288 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[640582:960873], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.634640 140231361497088 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.636315 140704714082304 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.636542 140704714082304 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.636702 140704714082304 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.637222 140231361497088 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.637440 140231361497088 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.637584 140231361497088 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.637382 140704714082304 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[960873:1281164], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.638234 140231361497088 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[320291:640582], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:43:05.657313 140399706380288 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 06:43:05.660951 140704714082304 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 06:43:05.662058 140231361497088 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:43:05.677829 139772030900224 dataset_info.py:358] Load dataset info from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.680407 139772030900224 dataset_info.py:411] Field info.description from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.680603 139772030900224 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I1011 06:43:05.680739 139772030900224 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.681377 139772030900224 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split train[0:320291], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:43:05.704411 139772030900224 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:43:05.754293 140399706380288 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.754893 140399706380288 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[25000:37500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.758002 140704714082304 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.758645 140704714082304 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[37500:50000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.760042 140231361497088 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.760658 140231361497088 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[12500:25000], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:43:05.775973 140399706380288 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 06:43:05.779988 140704714082304 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
W1011 06:43:05.782152 140231361497088 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:43:05.802015 139772030900224 mocking.py:151] Metadata found for imagenet2012 at /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
I1011 06:43:05.802646 139772030900224 logging_logger.py:35] Constructing tf.data.Dataset imagenet2012 for split validation[0:12500], from /home/ranran/flax/.tfds/metadata/imagenet2012/5.1.0
W1011 06:43:05.824019 139772030900224 options.py:588] options.experimental_threading is deprecated. Use options.threading instead.
I1011 06:43:30.270634 140399706380288 checkpoints.py:1054] Found no checkpoint files in /tmp/tmp62mkvr_m/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:43:30.327605 140704714082304 checkpoints.py:1054] Found no checkpoint files in /tmp/tmprexpr086/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:43:30.351497 140399706380288 train.py:378] Initial compilation, this might take some minutes...
I1011 06:43:30.413816 140704714082304 train.py:378] Initial compilation, this might take some minutes...
I1011 06:43:30.461287 140231361497088 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpqobamzrr/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:43:30.503510 139772030900224 checkpoints.py:1054] Found no checkpoint files in /tmp/tmpfjw3fkuj/ImagenetBenchmarkFakeData.test_fake_data with prefix checkpoint_
I1011 06:43:30.541634 140231361497088 train.py:378] Initial compilation, this might take some minutes...
I1011 06:43:30.582338 139772030900224 train.py:378] Initial compilation, this might take some minutes...
I1011 06:44:05.680504 140704714082304 train.py:384] Initial compilation completed.
I1011 06:44:05.882558 140399706380288 train.py:384] Initial compilation completed.
I1011 06:44:07.578282 139772030900224 train.py:384] Initial compilation completed.
I1011 06:44:07.858182 140231361497088 train.py:384] Initial compilation completed.
I1011 06:44:08.035283 140231361497088 benchmark.py:282] {"name": "ImagenetBenchmarkFakeData.test_fake_data", "succeeded": true, "metrics": {}, "extras": {"description": "ImageNet ResNet50 with fake data", "model_name": "resnet50", "parameters": "hp=true,bs=4096"}, "wall_time": 62.40354561805725}
[       OK ] ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Ran 1 test in 137.578s

OK
I1011 06:44:08.035782 140704714082304 benchmark.py:282] {"name": "ImagenetBenchmarkFakeData.test_fake_data", "succeeded": true, "metrics": {}, "extras": {"description": "ImageNet ResNet50 with fake data", "model_name": "resnet50", "parameters": "hp=true,bs=4096"}, "wall_time": 62.40485453605652}
[       OK ] ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Ran 1 test in 138.561s

OK
I1011 06:44:08.036287 139772030900224 benchmark.py:282] {"name": "ImagenetBenchmarkFakeData.test_fake_data", "succeeded": true, "metrics": {}, "extras": {"description": "ImageNet ResNet50 with fake data", "model_name": "resnet50", "parameters": "hp=true,bs=4096"}, "wall_time": 62.36168909072876}
[       OK ] ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Ran 1 test in 138.440s

OK
I1011 06:44:08.035931 140399706380288 benchmark.py:282] {"name": "ImagenetBenchmarkFakeData.test_fake_data", "succeeded": true, "metrics": {}, "extras": {"description": "ImageNet ResNet50 with fake data", "model_name": "resnet50", "parameters": "hp=true,bs=4096"}, "wall_time": 62.40787744522095}
[       OK ] ImagenetBenchmarkFakeData.test_fake_data
----------------------------------------------------------------------
Ran 1 test in 137.604s

OK
cgarciae commented 11 months ago

Interesting. @IvyZX maybe this is an Orbax issue?

IvyZX commented 11 months ago

Do you save your checkpoint in a local host directory (like /tmp/...) instead of a GCS directory? That will break because checkpoints.save_checkpoint_multiprocess and Orbax assume the checkpoint destination is accessible by all hosts. Try entering a GCS bucket path as ckpt_path and please let me know if the issue still happens. Sorry about the confusion!
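
For illustration, a rough pre-flight check along these lines could flag the problem before training starts. This is only a sketch: `is_shared_checkpoint_dir` and `check_workdir` are hypothetical helpers, not part of Flax or Orbax, and treating the `gs://` prefix as "reachable by all hosts" is an assumption (a shared network filesystem would also qualify). In a real run the process count would come from `jax.process_count()`.

```python
import warnings

# Hypothetical helpers for illustration; neither Flax nor Orbax provides them.
# Assumption: only GCS paths are treated as visible to every host.
SHARED_PREFIXES = ("gs://",)


def is_shared_checkpoint_dir(path: str) -> bool:
    """Return True if `path` looks like a location every host can reach."""
    return path.startswith(SHARED_PREFIXES)


def check_workdir(path: str, process_count: int) -> bool:
    """Warn and return False if a multi-host run uses a host-local directory."""
    if process_count > 1 and not is_shared_checkpoint_dir(path):
        warnings.warn(
            f"Checkpoint directory {path!r} looks host-local, but the job "
            f"runs on {process_count} processes; multi-process checkpointing "
            "expects a destination that all hosts can access."
        )
        return False
    return True
```

Calling `check_workdir(workdir, jax.process_count())` before the first save would then warn for the temporary directories seen in the log above, while a bucket path would pass.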