jaymody / picoGPT

An unnecessarily tiny implementation of GPT-2 in NumPy.
MIT License

Fix gpt2.py to work with Jax #10

Closed. certik closed this pull request 1 year ago.

certik commented 1 year ago

The issue was that the encoder returned a Python list, while `np.append()` (with `np` bound to Jax's NumPy module) expects Jax arrays as its arguments, not a list.

This patch makes it work with both NumPy and Jax.

Fixes #9.
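A minimal sketch of the kind of fix described above. It is not the PR's actual diff. It assumes a hypothetical `fake_encode` helper standing in for the GPT-2 encoder and a hypothetical `append_token` helper that takes the array module `xp` (either `numpy` or `jax.numpy`) as a parameter. The idea is to convert the encoder's list into an array with `xp.asarray` before calling `xp.append`, so the call receives arrays under either backend.

```python
import numpy as np


def fake_encode(prompt):
    """Hypothetical stand-in for the GPT-2 encoder: returns a plain Python list of ids."""
    return [ord(c) for c in prompt]


def append_token(xp, inputs, next_id):
    """Hypothetical helper: append next_id to inputs using array module xp.

    xp is assumed to be numpy or jax.numpy. Converting with xp.asarray first
    means xp.append always receives arrays rather than Python lists, which
    jax.numpy requires; plain NumPy accepts either.
    """
    inputs = xp.asarray(inputs)
    return xp.append(inputs, xp.asarray([next_id]))


ids = fake_encode("hi")          # a Python list: [104, 105]
ids = append_token(np, ids, 33)  # pass jax.numpy as xp instead to run under Jax
print(ids.tolist())              # [104, 105, 33]
```

Passing the module as a parameter keeps the sketch runnable without Jax installed. In a script that instead does `import jax.numpy as np`, the same `asarray` conversion lets one code path serve both backends.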

certik commented 1 year ago

This is ready.