jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Apple Silicon: error: failed to legalize operation 'mhlo.pad' #16366

Closed: mlaves closed this issue 6 months ago

mlaves commented 1 year ago

Description

When following the MNIST example from flax (https://github.com/google/flax/tree/main/examples/mnist/), the following error occurs with the latest jax-metal plugin, installed as described at https://developer.apple.com/metal/jax/:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
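
If it helps to localize the failure, here is a minimal sketch of what appears to trigger it (my assumption, based on the pooling line above: the transpose of the strided avg_pool's reduce_window pads the cotangent with non-zero interior padding):

import jax
import jax.numpy as jnp
from flax import linen as nn

# Differentiating through a strided average pool; its backward pass emits
# an mhlo.pad whose interior padding equals stride - 1.
x = jnp.ones((1, 14, 14, 1))
loss = lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)).sum()
g = jax.grad(loss)(x)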

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10

Which accelerator(s) are you using?

MPS

Additional system info

Python 3.11, macOS 13.4, Mac Mini M2 Pro

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

@kulinseth @shuhand0

hawkinsp commented 1 year ago

@mlaves is there more to the error? In particular, I think more details about the operation should be printed?

mlaves commented 1 year ago

@hawkinsp Sure, here's the full stack trace.

Traceback (most recent call last):
  File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
    app.run(main)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/main.py", line 60, in main
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 249, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
                                          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 160, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 2647, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1193, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1177, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1113, in _pjit_call_impl_python
    always_lower=False, lowering_platform=None).compile()
                                                ^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2319, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2638, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
     ^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
    app.run(main)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/main.py", line 60, in main
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
  y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
     ^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>
BradBalderson commented 1 year ago

I get the same error with almost identical specs, except Python 3.9 and an Apple Mac M2 Pro. It seems to be coming from within jax rather than jax-metal, however.

File ~/mambaforge/envs/mrff/lib/python3.9/site-packages/jax/_src/dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
    460   return backend.compile(built_c, compile_options=options,
    461                          host_callbacks=host_callbacks)
    462 # Some backends don't have `host_callbacks` option yet
    463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: io.py:164:0: error: failed to legalize operation 'mhlo.pad'
shuhand0 commented 1 year ago

Thanks for sending the bug report. The jax-metal plugin does not support pad with non-zero interior_padding. We will look into expanding the coverage and post an update here.
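
For reference, lax.pad takes a (low, high, interior) triple per dimension, and the output length along a dimension of size n is low + n + high + (n - 1) * interior; the op in the trace above grows 7 to 1 + 7 + 1 + 6 = 15. A minimal snippet that isolates the unsupported case (assuming lax.pad lowers directly to mhlo.pad):

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
jax.lax.pad(x, 0.0, [(1, 1, 0)])  # edge padding only: supported
jax.lax.pad(x, 0.0, [(0, 0, 1)])  # interior padding of 1: the unsupported case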

BradBalderson commented 1 year ago

I am running a pretrained model; I wonder if there is a way to change my inputs/tokenisation to avoid the interior_padding and circumvent this issue?

hawkinsp commented 1 year ago

@BradBalderson It's impossible to say without more details on how the operator is used in the model. If it is applied to one of the model inputs, perhaps.

BradBalderson commented 1 year ago

Ah OK, I will see if it is due to the inputs. Thanks for the fast reply and feedback @hawkinsp and @shuhand0

hawkinsp commented 1 year ago

BTW, you can implement interior padding with edge padding, if the interior padding is from your user code.

For example, to pad the innermost dimension, you do this:
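
(A sketch of the idea; interior_pad_last is an illustrative helper name, not a JAX API.)

import jax.numpy as jnp

def interior_pad_last(x, pad_value=0.0):
    # (..., n) -> (..., 2n - 1): one pad_value between consecutive elements,
    # built from edge padding plus a reshape, with no interior padding.
    y = x[..., None]                                 # (..., n, 1)
    y = jnp.pad(y, [(0, 0)] * x.ndim + [(0, 1)],     # edge padding only
                constant_values=pad_value)           # (..., n, 2)
    y = y.reshape(*x.shape[:-1], -1)                 # (..., 2n)
    return y[..., :-1]                               # drop the trailing pad

# interior_pad_last(jnp.arange(3.0)) -> [0., 0., 1., 0., 2.]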

But... it might just be better to wait for our colleagues from Apple to fix the plugin :-)

Mixpap commented 8 months ago

I encountered the same bug while calculating the grad of a loss function for a physics-informed NN problem on a Mac M1.

JAX version: 0.4.20. Below is the stack trace. I am trying to construct a minimal reproducible example and will update in a new comment, but the problem is very complex and it is difficult to recreate a simpler version of it.

{
    "name": "XlaRuntimeError",
    "message": "UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
",
    "stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/traitlets/config/application.py:1077, in launch_instance()
   1076 app.initialize(argv)
-> 1077 app.start()

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelapp.py:701, in start()
    700 try:
--> 701     self.io_loop.start()
    702 except KeyboardInterrupt:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/tornado/platform/asyncio.py:195, in start()
    194 def start(self) -> None:
--> 195     self.asyncio_loop.run_forever()

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:607, in run_forever()
    606 while True:
--> 607     self._run_once()
    608     if self._stopping:

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:1922, in _run_once()
   1921     else:
-> 1922         handle._run()
   1923 handle = None

File ~/miniconda3/envs/metal/lib/python3.11/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
    533 try:
--> 534     await self.process_one()
    535 except Exception:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:523, in process_one()
    522         return
--> 523 await dispatch(*args)

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
    428     if inspect.isawaitable(result):
--> 429         await result
    430 except Exception:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:767, in execute_request()
    766 if inspect.isawaitable(reply_content):
--> 767     reply_content = await reply_content
    769 # Flush output before sending the reply.

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/ipkernel.py:429, in do_execute()
    428 if accepts_params[\"cell_id\"]:
--> 429     res = shell.run_cell(
    430         code,
    431         store_history=store_history,
    432         silent=silent,
    433         cell_id=cell_id,
    434     )
    435 else:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
   3050 try:
-> 3051     result = self._run_cell(
   3052         raw_cell, store_history, silent, shell_futures, cell_id
   3053     )
   3054 finally:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
   3105 try:
-> 3106     result = runner(coro)
   3107 except BaseException as e:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
   3308 interactivity = \"none\" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3312        interactivity=interactivity, compiler=compiler, result=result)
   3314 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
   3492     asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
   3494     return True

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
   3552     else:
-> 3553         exec(code_obj, self.user_global_ns, self.user_ns)
   3554 finally:
   3555     # Reset our crash handler in place

Cell In[288], line 1
----> 1 jax.grad(los1)(params)

Cell In[286], line 1, in los1()
----> 1 def los1(params): return jnp.mean(los_physics1(params,tt,xx)**2)

Cell In[243], line 15, in los_physics1()
     13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
---> 15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
     15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
     15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)

Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
      9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)

Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
      9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
     12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)

Cell In[243], line 6, in los_physics1.<locals>.<lambda>()
      5 rho=lambda t,x: NN(params,t,x)[0]
----> 6 Vx=lambda t,x: NN(params,t,x)[1]
      7 #p=lambda t,x: NN(params,t,x)[2]

JaxStackTraceBeforeTransformation: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
Cell In[288], line 1
----> 1 jax.grad(los1)(params)

    [... skipping hidden 21 frame]

File ~/miniconda3/envs/metal/lib/python3.11/site-packages/jax/_src/compiler.py:255, in backend_compile(backend, module, options, host_callbacks)
    250   return backend.compile(built_c, compile_options=options,
    251                          host_callbacks=host_callbacks)
    252 # Some backends don't have `host_callbacks` option yet
    253 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    254 # to take in `host_callbacks`
--> 255 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
"
}
mar-muel commented 6 months ago

Hello - I'm getting this error when running the following very simple operation:

import jax.numpy as jnp
jnp.cumprod(jnp.arange(10))
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/martin/miniconda3/envs/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 657, in cumulative_reduction
    return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.pad'
<stdin>:1:0: note: called from
<stdin>:1:0: note: see current operation: %43 = "mhlo.pad"(%42, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<1xsi32>, tensor<si32>) -> tensor<2xsi32>

My env

>>> jax.print_environment_info()
jax:    0.4.20
jaxlib: 0.4.20
numpy:  1.26.4
python: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1

Would be really cool if someone could fix this, as it makes jax-metal pretty much unusable for me :(
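
(Until the fix lands, a possible interim workaround, assuming the CPU backend is available alongside METAL: the note above shows the cumulative reduction lowering to an mhlo.pad with interior_padding = 1, so placing just this op on CPU sidesteps the Metal legalization.)

import jax
import jax.numpy as jnp

# jax.default_device is a standard context manager; computations traced
# inside it are placed on the given device instead of the default METAL one.
with jax.default_device(jax.devices("cpu")[0]):
    out = jnp.cumprod(jnp.arange(10))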

shuhand0 commented 6 months ago

The fix for pad with non-zero interior_padding will be in the upcoming jax-metal release and will work on macOS 14.4.

shuhand0 commented 6 months ago

The fix is in jax-metal 0.0.6. Some output from running flax/examples/mnist:

python main.py --workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
I0312 14:30:01.288698 7932478208 xla_bridge.py:660] Unable to initialize backend 'cuda': 
I0312 14:30:01.288793 7932478208 xla_bridge.py:660] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0312 14:30:01.290362 7932478208 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/shuhan/miniconda3/envs/test/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
W0312 14:30:01.290431 7932478208 xla_bridge.py:758] Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 14:30:01.290494: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0312 14:30:01.301946 7932478208 main.py:50] JAX process: 0 / 1
I0312 14:30:01.302008 7932478208 main.py:51] JAX local devices: [METAL(id=0)]
I0312 14:30:01.302042 7932478208 local.py:45] Setting task status: process_index: 0, process_count: 1
I0312 14:30:01.302149 7932478208 local.py:50] Created artifact workdir of type ArtifactType.DIRECTORY and value /tmp/mnist.
I0312 14:30:01.302522 7932478208 dataset_info.py:358] Load dataset info from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:01.303414 7932478208 dataset_info.py:411] Field info.citation from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303492 7932478208 dataset_info.py:411] Field info.splits from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303527 7932478208 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303578 7932478208 dataset_builder.py:351] Reusing dataset mnist (/Users/shuhan/tensorflow_datasets/mnist/3.0.1)
I0312 14:30:01.303609 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split train, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
WARNING:tensorflow:From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
W0312 14:30:01.332822 7932478208 deprecation.py:50] From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
I0312 14:30:01.966295 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split test, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:07.488775 7932478208 train.py:146] epoch:  1, train_loss: 0.2369, train_accuracy: 93.05, test_loss: 0.0590, test_accuracy: 98.07
I0312 14:30:10.687863 7932478208 train.py:146] epoch:  2, train_loss: 0.0611, train_accuracy: 98.11, test_loss: 0.0547, test_accuracy: 98.17
I0312 14:30:13.719958 7932478208 train.py:146] epoch:  3, train_loss: 0.0423, train_accuracy: 98.68, test_loss: 0.0330, test_accuracy: 98.73
I0312 14:30:16.748282 7932478208 train.py:146] epoch:  4, train_loss: 0.0308, train_accuracy: 99.00, test_loss: 0.0302, test_accuracy: 99.04
I0312 14:30:19.781277 7932478208 train.py:146] epoch:  5, train_loss: 0.0250, train_accuracy: 99.21, test_loss: 0.0306, test_accuracy: 99.03