To reproduce on macOS (Apple M1) with the latest master (commit 018a1e1796d7ea3e96032d9667042316c8fa7864), first run the unmodified script, which works:
$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating: 100%|█████████████████████████████████| 8/8 [00:03<00:00, 2.44it/s]
the most powerful machines on the planet.
Then apply the following patch, which replaces NumPy with `jax.numpy`:
diff --git a/gpt2.py b/gpt2.py
index 62549bc..daf5685 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -1,4 +1,4 @@
-import numpy as np
+import jax.numpy as np
 def gelu(x):
and run the same command again, which now fails:
$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating: 0%| | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
fire.Fire(main)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
inputs = np.append(inputs, [next_id]) # append prediction to input
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 694, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 240, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 351, in _xla_callable_uncached
computation = sharded_lowering(fun, device, backend, name, donated_invars,
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 342, in sharded_lowering
return pxla.lower_sharding_computation(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2797, in lower_sharding_computation
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2073, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2006, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
return concatenate([ravel(arr), ravel(values)], 0)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 698, in cache_miss
top_trace.process_call(primitive, fun_, tracers, params))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1747, in process_call
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2035, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
_stackable(a) or _check_arraylike("ravel", a)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
fire.Fire(main)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
inputs = np.append(inputs, [next_id]) # append prediction to input
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
return concatenate([ravel(arr), ravel(values)], 0)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
_stackable(a) or _check_arraylike("ravel", a)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.
How to reproduce using the latest master (018a1e1796d7ea3e96032d9667042316c8fa7864) on macOS M1:
Then apply the following patch:
and:
I am running in the following Conda environment: