jaymody / picoGPT

An unnecessarily tiny implementation of GPT-2 in NumPy.
MIT License

Using jax.numpy instead of numpy gives TypeError on macOS #9

Closed: certik closed this issue 1 year ago

certik commented 1 year ago

How to reproduce using the latest master (018a1e1796d7ea3e96032d9667042316c8fa7864) on macOS M1:

$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating: 100%|█████████████████████████████████| 8/8 [00:03<00:00,  2.44it/s]
 the most powerful machines on the planet.

Then apply the following patch:

diff --git a/gpt2.py b/gpt2.py
index 62549bc..daf5685 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -1,4 +1,4 @@
-import numpy as np
+import jax.numpy as np

 def gelu(x):

and run the same command again:

$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating:   0%|                                         | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
    fire.Fire(main)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
    inputs = np.append(inputs, [next_id])  # append prediction to input
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 694, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 240, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 351, in _xla_callable_uncached
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 342, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2797, in lower_sharding_computation
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2073, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2006, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
    return concatenate([ravel(arr), ravel(values)], 0)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 698, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1747, in process_call
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2035, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
    _stackable(a) or _check_arraylike("ravel", a)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
    fire.Fire(main)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
  File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
    inputs = np.append(inputs, [next_id])  # append prediction to input
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
    return concatenate([ravel(arr), ravel(values)], 0)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
    _stackable(a) or _check_arraylike("ravel", a)
  File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.

I am running in the following Conda environment:

$ conda env export
name: pico
channels:
  - conda-forge
dependencies:
  - appdirs=1.4.4=pyh9f0ad1d_0
  - brotlipy=0.7.0=py39h02fc5c5_1005
  - bzip2=1.0.8=h3422bc3_4
  - c-ares=1.18.1=h3422bc3_0
  - ca-certificates=2022.12.7=h4653dfc_0
  - cffi=1.15.1=py39h7e6b969_3
  - cryptography=39.0.1=py39he2a39a8_0
  - idna=3.4=pyhd8ed1ab_0
  - jax=0.4.3=pyhd8ed1ab_0
  - jaxlib=0.4.3=cpu_py39h99d3290_1
  - libabseil=20220623.0=cxx17_h28b99d4_6
  - libblas=3.9.0=16_osxarm64_openblas
  - libcblas=3.9.0=16_osxarm64_openblas
  - libcxx=14.0.6=h2692d47_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=11_3_0_hd922786_27
  - libgfortran5=11.3.0=hdaf2cc0_27
  - libgrpc=1.51.1=hb15be72_1
  - liblapack=3.9.0=16_osxarm64_openblas
  - libopenblas=0.3.21=openmp_hc731615_3
  - libprotobuf=3.21.12=hb5ab8b9_0
  - libsqlite=3.40.0=h76d750c_0
  - libzlib=1.2.13=h03a7124_4
  - llvm-openmp=15.0.7=h7cfbb63_0
  - ncurses=6.3=h07bb92c_1
  - openssl=3.0.8=h03a7124_0
  - opt_einsum=3.3.0=pyhd8ed1ab_1
  - packaging=23.0=pyhd8ed1ab_0
  - pip=23.0=pyhd8ed1ab_0
  - pooch=1.6.0=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pyopenssl=23.0.0=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.16=hea58f1e_0_cpython
  - python_abi=3.9=3_cp39
  - re2=2023.02.01=hb7217d7_0
  - readline=8.1.2=h46ed386_0
  - scipy=1.10.0=py39h18313fe_2
  - setuptools=67.1.0=pyhd8ed1ab_0
  - tk=8.6.12=he1e0b03_0
  - tzdata=2022g=h191b570_0
  - urllib3=1.26.14=pyhd8ed1ab_0
  - wheel=0.38.4=pyhd8ed1ab_0
  - xz=5.2.6=h57fd34a_0
  - zlib=1.2.13=h03a7124_4
  - pip:
    - absl-py==1.4.0
    - astunparse==1.6.3
    - cachetools==5.3.0
    - certifi==2022.12.7
    - charset-normalizer==2.0.12
    - fire==0.5.0
    - flatbuffers==23.1.21
    - gast==0.4.0
    - google-auth==2.16.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.51.1
    - h5py==3.8.0
    - importlib-metadata==6.0.0
    - keras==2.11.0
    - libclang==15.0.6.1
    - markdown==3.4.1
    - markupsafe==2.1.2
    - numpy==1.24.1
    - oauthlib==3.2.2
    - protobuf==3.19.6
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - regex==2017.4.5
    - requests==2.27.1
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - six==1.16.0
    - tensorboard==2.11.2
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow-estimator==2.11.0
    - tensorflow-macos==2.11.0
    - termcolor==2.2.0
    - tqdm==4.64.0
    - typing-extensions==4.4.0
    - werkzeug==2.2.2
    - wrapt==1.14.1
    - zipp==3.13.0
prefix: /Users/ondrej/mambaforge/envs/pico

certik commented 1 year ago

This particular error is fixed by https://github.com/jaymody/picoGPT/pull/10.
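
For anyone hitting the same traceback: the last frames show that jax.numpy's append calls ravel on both arguments, and ravel rejects Python lists, whereas plain NumPy converts them silently. Below is a minimal sketch of one way to avoid this, assuming it is enough to convert both arguments to arrays before appending. The names `append_token` and `xp` are hypothetical and not part of picoGPT, and the linked pull request may well fix it differently.

```python
import numpy as onp  # plain NumPy, used here so the sketch runs without JAX installed


def append_token(xp, inputs, next_id):
    """Append one token id to a 1-D sequence of ids.

    xp is the array module in use (numpy, or jax.numpy after the patch above).
    Both arguments are converted to arrays first, so no Python list ever
    reaches append(); this is an assumed workaround, not picoGPT's code.
    """
    return xp.append(xp.asarray(inputs), xp.asarray([next_id]))


# The prompt ids may start out as a plain Python list.
inputs = [1, 2, 3]
inputs = append_token(onp, inputs, 4)
print(inputs.tolist())
```

With JAX installed, the same helper should also work when called as `append_token(jax.numpy, inputs, next_id)`, since `jax.numpy.asarray` accepts lists even though `jax.numpy.append` does not.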