@kulinseth @shuhand0
@mlaves is there more to the error? In particular, I think more details about the operation should be printed?
@hawkinsp Sure, here's the full stacktrace.
Traceback (most recent call last):
File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
app.run(main)
File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/Users/max/flax/examples/mnist/main.py", line 60, in main
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
state, train_loss, train_accuracy = train_epoch(state, train_ds,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 249, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 160, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 2647, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1193, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1177, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1113, in _pjit_call_impl_python
always_lower=False, lowering_platform=None).compile()
^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2319, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2638, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/jax-metal/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/max/flax/examples/mnist/main.py", line 65, in <module>
app.run(main)
File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/Users/max/jax-metal/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/Users/max/flax/examples/mnist/main.py", line 60, in main
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
File "/Users/max/flax/examples/mnist/train.py", line 140, in train_and_evaluate
state, train_loss, train_accuracy = train_epoch(state, train_ds,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/max/flax/examples/mnist/train.py", line 89, in train_epoch
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: error: failed to legalize operation 'mhlo.pad'
y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
^
/Users/max/jax-metal/lib/python3.11/site-packages/flax/linen/pooling.py:69:6: note: see current operation: %207 = "mhlo.pad"(%206, %19) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor<f32>) -> tensor<128x15x15x64xf32>
I get the same error with almost identical specs, except Python 3.9 and an Apple Mac M2 Pro. It seems to be coming from within jax, as opposed to jax-metal, however.
File ~/mambaforge/envs/mrff/lib/python3.9/site-packages/jax/_src/dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
460 return backend.compile(built_c, compile_options=options,
461 host_callbacks=host_callbacks)
462 # Some backends don't have `host_callbacks` option yet
463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: io.py:164:0: error: failed to legalize operation 'mhlo.pad'
Thanks for sending the bug report. The JAX-Metal plugin does not support pad with non-zero interior_padding. We will look into expanding the coverage and update here.
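For reference, interior padding is the third entry of each `(low, high, interior)` triple in a `lax.pad` padding config, so a minimal sketch of a reproducer (my own construction, not an official test case) would be:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)
# padding_config holds one (low, high, interior) triple per dimension;
# a non-zero third entry is the interior padding that fails to legalize
# on the METAL backend before the fix.
y = lax.pad(x, 0.0, [(0, 0, 1)])
print(y)  # on CPU: [0. 0. 1. 0. 2. 0. 3.]
```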
I am running a pretrained model; I wonder if there is a way to change my inputs/tokenisation to try and avoid the interior_padding, to circumvent this issue?
@BradBalderson It will be impossible to say without more details on how the operator is used in the model. If it is applied to one of the model inputs, perhaps.
Ah OK, I will see if it is due to the inputs. Thanks for the fast reply and feedback @hawkinsp and @shuhand0
BTW, you can implement interior padding with edge padding, if the interior padding is from your user code.
For example, to pad the innermost dimension, you do this: reshape the array to add a trailing dimension of size 1, edge-pad that new dimension up to size 1 + interior_padding, reshape to merge the last two dimensions back together, and finally slice the extra interior_padding elements off the end. But... it might just be better to wait for our colleagues from Apple to fix the plugin :-)
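To make that recipe concrete, here is a minimal sketch (the helper name is mine; it assumes non-negative interior padding on the innermost axis):

```python
import jax.numpy as jnp

def interior_pad_last_axis(x, pad_value, interior):
    """Emulate lax.pad's interior padding on the last axis via edge padding only."""
    # 1) Add a trailing axis of size 1: [..., n] -> [..., n, 1].
    y = x[..., None]
    # 2) Edge-pad the new axis to size 1 + interior with the pad value.
    y = jnp.pad(y, [(0, 0)] * x.ndim + [(0, interior)], constant_values=pad_value)
    # 3) Merge the last two axes: [..., n, 1 + interior] -> [..., n * (1 + interior)].
    y = y.reshape(*x.shape[:-1], x.shape[-1] * (1 + interior))
    # 4) Slice the trailing run of `interior` pad elements off the end.
    return y[..., : y.shape[-1] - interior]

# Matches lax.pad(x, 0.0, [(0, 0, 2)]) on a 1-D input:
print(interior_pad_last_axis(jnp.array([1., 2., 3.]), 0.0, 2))
# [1. 0. 0. 2. 0. 0. 3.]
```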
I encountered the same bug when trying to calculate the grad of a loss function for a physics-informed NN problem on a Mac M1.
JAX version: 0.4.20. Below is the stack trace. I am trying to construct a minimal reproducible example and will update with a new comment, but the problem is very complex and it is difficult to recreate a simpler version of it.
{
"name": "XlaRuntimeError",
"message": "UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
",
"stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel_launcher.py:17
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/traitlets/config/application.py:1077, in launch_instance()
1076 app.initialize(argv)
-> 1077 app.start()
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelapp.py:701, in start()
700 try:
--> 701 self.io_loop.start()
702 except KeyboardInterrupt:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/tornado/platform/asyncio.py:195, in start()
194 def start(self) -> None:
--> 195 self.asyncio_loop.run_forever()
File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:607, in run_forever()
606 while True:
--> 607 self._run_once()
608 if self._stopping:
File ~/miniconda3/envs/metal/lib/python3.11/asyncio/base_events.py:1922, in _run_once()
1921 else:
-> 1922 handle._run()
1923 handle = None
File ~/miniconda3/envs/metal/lib/python3.11/asyncio/events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in dispatch_queue()
533 try:
--> 534 await self.process_one()
535 except Exception:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:523, in process_one()
522 return
--> 523 await dispatch(*args)
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:429, in dispatch_shell()
428 if inspect.isawaitable(result):
--> 429 await result
430 except Exception:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/kernelbase.py:767, in execute_request()
766 if inspect.isawaitable(reply_content):
--> 767 reply_content = await reply_content
769 # Flush output before sending the reply.
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/ipkernel.py:429, in do_execute()
428 if accepts_params[\"cell_id\"]:
--> 429 res = shell.run_cell(
430 code,
431 store_history=store_history,
432 silent=silent,
433 cell_id=cell_id,
434 )
435 else:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3051, in run_cell()
3050 try:
-> 3051 result = self._run_cell(
3052 raw_cell, store_history, silent, shell_futures, cell_id
3053 )
3054 finally:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3106, in _run_cell()
3105 try:
-> 3106 result = runner(coro)
3107 except BaseException as e:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3311, in run_cell_async()
3308 interactivity = \"none\" if silent else self.ast_node_interactivity
-> 3311 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3312 interactivity=interactivity, compiler=compiler, result=result)
3314 self.last_execution_succeeded = not has_raised
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3493, in run_ast_nodes()
3492 asy = compare(code)
-> 3493 if await self.run_code(code, result, async_=asy):
3494 return True
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3553, in run_code()
3552 else:
-> 3553 exec(code_obj, self.user_global_ns, self.user_ns)
3554 finally:
3555 # Reset our crash handler in place
Cell In[288], line 1
----> 1 jax.grad(los1)(params)
Cell In[286], line 1, in los1()
----> 1 def los1(params): return jnp.mean(los_physics1(params,tt,xx)**2)
Cell In[243], line 15, in los_physics1()
13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
---> 15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)
Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)
Cell In[243], line 13, in los_physics1.<locals>.<lambda>()
12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
---> 13 rho_V_x_x = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),1)(t,x)
15 return rho_t(tc,xc)+rho_V_x_x(tc,xc)
Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
Cell In[243], line 10, in los_physics1.<locals>.<lambda>()
9 rho_t= lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)),0)(t,x)
---> 10 Vx_x= lambda t,x: jax.grad(lambda t,x: jnp.sum(Vx(t,x)),1)(t,x)
12 #rho_V_x_t = lambda t,x: jax.grad(lambda t,x: jnp.sum(rho(t,x)*Vx_x(t,x)),0)(t,x)
Cell In[243], line 6, in los_physics1.<locals>.<lambda>()
5 rho=lambda t,x: NN(params,t,x)[0]
----> 6 Vx=lambda t,x: NN(params,t,x)[1]
7 #p=lambda t,x: NN(params,t,x)[2]
JaxStackTraceBeforeTransformation: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
XlaRuntimeError Traceback (most recent call last)
Cell In[288], line 1
----> 1 jax.grad(los1)(params)
[... skipping hidden 21 frame]
File ~/miniconda3/envs/metal/lib/python3.11/site-packages/jax/_src/compiler.py:255, in backend_compile(backend, module, options, host_callbacks)
250 return backend.compile(built_c, compile_options=options,
251 host_callbacks=host_callbacks)
252 # Some backends don't have `host_callbacks` option yet
253 # TODO(sharadmv): remove this fallback when all backends allow `compile`
254 # to take in `host_callbacks`
--> 255 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: /var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: error: failed to legalize operation 'mhlo.pad'
/var/folders/l6/g1b2w7y10rg8qzc142j02b240000gn/T/ipykernel_73686/1915339887.py:2:47: note: see current operation: %99 = \"mhlo.pad\"(%98, %1) {edge_padding_high = dense<0> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<1024x2xf32>, tensor<f32>) -> tensor<1024x1xf32>
"
}
Hello - I'm getting this error when running the following very simple operation:
import jax.numpy as jnp
jnp.cumprod(jnp.arange(10))
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/martin/miniconda3/envs/jax-metal/lib/python3.10/site-packages/jax/_src/numpy/reductions.py", line 657, in cumulative_reduction
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'mhlo.pad'
<stdin>:1:0: note: called from
<stdin>:1:0: note: see current operation: %43 = "mhlo.pad"(%42, %1) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<1xsi32>, tensor<si32>) -> tensor<2xsi32>
My env
>>> jax.print_environment_info()
jax: 0.4.20
jaxlib: 0.4.20
numpy: 1.26.4
python: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
Would be really cool if someone could fix this, as it makes jax-metal pretty much unusable for me :(
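While waiting for a plugin fix, one interim workaround for cumulative reductions like this is to compute them with a sequential `lax.scan`, which avoids the associative-scan lowering that emits the interior-padded `mhlo.pad` (the traceback above shows the pad coming from `reductions.py`). A sketch, with a helper name of my choosing, assuming `lax.scan` itself is supported by the Metal backend:

```python
import jax
import jax.numpy as jnp

def cumprod_seq(x):
    """Sequential cumulative product via lax.scan (no pad-based associative scan)."""
    def step(carry, v):
        carry = carry * v
        return carry, carry
    _, out = jax.lax.scan(step, jnp.ones((), x.dtype), x)
    return out

print(cumprod_seq(jnp.arange(1, 6)))  # [  1   2   6  24 120], same as jnp.cumprod
```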
For pad with interior_padding, the fix will be in the upcoming jax-metal release and will work on macOS 14.4.
The fix is in jax-metal 0.0.6. Some output from running flax/examples/mnist:
python main.py --workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.05 --config.num_epochs=5
I0312 14:30:01.288698 7932478208 xla_bridge.py:660] Unable to initialize backend 'cuda':
I0312 14:30:01.288793 7932478208 xla_bridge.py:660] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0312 14:30:01.290362 7932478208 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/shuhan/miniconda3/envs/test/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
W0312 14:30:01.290431 7932478208 xla_bridge.py:758] Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 14:30:01.290494: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0312 14:30:01.301946 7932478208 main.py:50] JAX process: 0 / 1
I0312 14:30:01.302008 7932478208 main.py:51] JAX local devices: [METAL(id=0)]
I0312 14:30:01.302042 7932478208 local.py:45] Setting task status: process_index: 0, process_count: 1
I0312 14:30:01.302149 7932478208 local.py:50] Created artifact workdir of type ArtifactType.DIRECTORY and value /tmp/mnist.
I0312 14:30:01.302522 7932478208 dataset_info.py:358] Load dataset info from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:01.303414 7932478208 dataset_info.py:411] Field info.citation from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303492 7932478208 dataset_info.py:411] Field info.splits from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303527 7932478208 dataset_info.py:411] Field info.module_name from disk and from code do not match. Keeping the one from code.
I0312 14:30:01.303578 7932478208 dataset_builder.py:351] Reusing dataset mnist (/Users/shuhan/tensorflow_datasets/mnist/3.0.1)
I0312 14:30:01.303609 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split train, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
WARNING:tensorflow:From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
W0312 14:30:01.332822 7932478208 deprecation.py:50] From /Users/shuhan/miniconda3/envs/test/lib/python3.9/site-packages/tensorflow_datasets/core/dataset_builder.py:622: get_single_element (from tensorflow.python.data.experimental.ops.get_single_element) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.
I0312 14:30:01.966295 7932478208 logging_logger.py:35] Constructing tf.data.Dataset mnist for split test, from /Users/shuhan/tensorflow_datasets/mnist/3.0.1
I0312 14:30:07.488775 7932478208 train.py:146] epoch: 1, train_loss: 0.2369, train_accuracy: 93.05, test_loss: 0.0590, test_accuracy: 98.07
I0312 14:30:10.687863 7932478208 train.py:146] epoch: 2, train_loss: 0.0611, train_accuracy: 98.11, test_loss: 0.0547, test_accuracy: 98.17
I0312 14:30:13.719958 7932478208 train.py:146] epoch: 3, train_loss: 0.0423, train_accuracy: 98.68, test_loss: 0.0330, test_accuracy: 98.73
I0312 14:30:16.748282 7932478208 train.py:146] epoch: 4, train_loss: 0.0308, train_accuracy: 99.00, test_loss: 0.0302, test_accuracy: 99.04
I0312 14:30:19.781277 7932478208 train.py:146] epoch: 5, train_loss: 0.0250, train_accuracy: 99.21, test_loss: 0.0306, test_accuracy: 99.03
Description
When following the MNIST example from flax (https://github.com/google/flax/tree/main/examples/mnist/), the following error occurs when using the latest jax-metal plugin installed as described at https://developer.apple.com/metal/jax/ :
What jax/jaxlib version are you using?
jax 0.4.11, jaxlib v0.4.10
Which accelerator(s) are you using?
MPS
Additional system info
Python 3.11, macOS 13.4, Mac Mini M2 Pro
NVIDIA GPU info
No response