artus-LYTiQ opened 1 month ago
I looked into it and gave up; I've fallen back to torch MPS for now.
I don't think JAX is ready for Apple silicon, given the complex numbers we use for RoPE. The difference between the torch and JAX complex exponentials computed via `precompute_freqs_cis` is > 1e-3. That is not ideal, but at the scale of this experiment it is fine, based on what I've seen.
To add to this: comparing torch MPS against JAX CPU, all other computations are within the 1e-3 threshold except the complex sinusoid calculation.
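To put a number on a discrepancy like this, one option is to compare each implementation against a float64 reference. The sketch below uses NumPy on CPU and assumes a Llama-style `precompute_freqs_cis` (inverse frequencies `theta ** -(2k / dim)`, an outer product with positions, then `exp(1j * angles)`). It ignores `use_scaled_rope`, the default `theta = 500000` is an assumption, and `max_abs_diff` is a hypothetical helper. Because it runs only on CPU, it isolates float32 rounding and says nothing about MPS or Metal specifically.

```python
import numpy as np


def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0,
                         dtype=np.float32) -> np.ndarray:
    """Llama-style rotary table (sketch; RoPE scaling is not implemented).

    Returns exp(1j * t * f_k) for positions t < end and inverse frequencies
    f_k = theta ** -(2k / dim), k < dim // 2, computed in the given float dtype.
    """
    exponents = np.arange(0, dim, 2, dtype=dtype)[: dim // 2] / dtype(dim)
    inv_freq = dtype(1.0) / (dtype(theta) ** exponents)
    angles = np.outer(np.arange(end, dtype=dtype), inv_freq)  # (end, dim // 2)
    cdtype = np.result_type(dtype, np.complex64)  # complex64 or complex128
    return np.exp(1j * angles.astype(cdtype))


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest element-wise |a - b|, evaluated in complex128."""
    return float(np.max(np.abs(a.astype(np.complex128) - b.astype(np.complex128))))


if __name__ == "__main__":
    ref = precompute_freqs_cis(64, 4096, dtype=np.float64)
    low = precompute_freqs_cis(64, 4096, dtype=np.float32)
    diff = max_abs_diff(ref, low)
    print(f"max |complex64 - complex128| = {diff:.2e} (within 1e-3: {diff <= 1e-3})")
```

Running the same comparison with the torch and JAX tables in place of `ref` and `low` gives the per-tensor check described above.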
silicon......supremacy..... i utter, writhing.....the day will soon come though
Any example of torch MPS? I admit I haven't used PyTorch since getting my Mac :guilty:
I have it working with jax-metal and am working on a pull request now. The first change is to use floats instead of complex numbers to handle RoPE. The second is to reimplement `top_k`, because `jax.lax.top_k` doesn't appear to work with jax-metal. Their CI test logs seem to show the same: https://github.com/jax-ml/jax/actions/workflows/metal_plugin_ci.yml
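Here is a minimal sketch of a `top_k` that avoids `jax.lax.top_k`, written with NumPy for portability. `jax.numpy` also provides `argsort` and `take_along_axis`, but this thread does not confirm that they lower correctly on jax-metal, so treat that as an assumption. The function name and the `(values, indices)` return order mirror `jax.lax.top_k(operand, k)`, which works on the last axis. The implementation itself is hypothetical and uses a full sort, so it costs O(n log n) per row rather than using a dedicated top-k kernel.

```python
import numpy as np


def top_k(x: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical stand-in for lax.top_k along the last axis.

    Returns (values, indices), both sorted by descending value.
    """
    if not 0 < k <= x.shape[-1]:
        raise ValueError(f"k must be in [1, {x.shape[-1]}], got {k}")
    # Sorting -x ascending gives x descending; a stable sort keeps the
    # lower index first when two values are equal.
    order = np.argsort(-x, axis=-1, kind="stable")
    indices = order[..., :k]
    values = np.take_along_axis(x, indices, axis=-1)
    return values, indices


if __name__ == "__main__":
    logits = np.array([[0.1, 2.0, -1.0, 0.5]])
    vals, idx = top_k(logits, 2)
    print(vals, idx)
```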
You could implement `precompute_freqs_cis` by separating the cosine and sine components, i.e. splitting the real and imaginary parts so you work only with real numbers.
But I would test extensively that the original JAX version and the updated version produce the same values (within the 1e-3 threshold), and keep the original method as well.
On cloud deployments the original method would still be more efficient.
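The split works because, if one pair of head-dimension values is written as $x = x_r + i x_i$ and the rotation as $e^{i\theta} = \cos\theta + i\sin\theta$, then

$$x\,e^{i\theta} = (x_r\cos\theta - x_i\sin\theta) + i\,(x_r\sin\theta + x_i\cos\theta),$$

so only real cos/sin tables are needed. The NumPy sketch below keeps both paths behind a hypothetical `use_complex` flag and checks them against each other with the 1e-3 threshold used in this thread. It makes four assumptions:

- adjacent elements `(2k, 2k+1)` form the real and imaginary parts (the Llama convention; the repo's layout may differ);
- the input has shape `(seq, dim)`;
- `theta = 500000`;
- `rope_tables` and `apply_rope` are illustrative names, not functions from the repo.

```python
import numpy as np


def rope_tables(dim: int, end: int, theta: float = 500000.0):
    """Hypothetical real-valued replacement for precompute_freqs_cis.

    Returns (cos, sin), each float32 with shape (end, dim // 2); no complex dtype.
    """
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
    angles = np.outer(np.arange(end), inv_freq).astype(np.float32)
    return np.cos(angles), np.sin(angles)


def apply_rope(x: np.ndarray, cos: np.ndarray, sin: np.ndarray,
               use_complex: bool = False) -> np.ndarray:
    """Rotate x of shape (seq, dim) by position-dependent angles.

    Elements (2k, 2k+1) are treated as (real, imag) parts (assumed layout).
    use_complex=True keeps the original complex-multiply path for comparison.
    """
    seq = x.shape[0]
    xr, xi = x[..., 0::2], x[..., 1::2]
    c, s = cos[:seq], sin[:seq]
    if use_complex:
        rotated = (xr + 1j * xi) * (c + 1j * s)
        out_r, out_i = rotated.real, rotated.imag
    else:
        # The same product, expanded into real arithmetic only.
        out_r = xr * c - xi * s
        out_i = xr * s + xi * c
    out = np.empty_like(x)
    out[..., 0::2] = out_r
    out[..., 1::2] = out_i
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((128, 64)).astype(np.float32)
    cos, sin = rope_tables(64, 128)
    diff = np.max(np.abs(apply_rope(x, cos, sin, use_complex=True)
                         - apply_rope(x, cos, sin, use_complex=False)))
    print(f"max |complex - real| = {diff:.2e}, within 1e-3: {bool(diff <= 1e-3)}")
```

The flag defaults to the real path so it can run on backends without complex support. Backends that handle complex dtypes can still select the original path.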
It's true I did the complex multiply the obvious way. It could be put behind a flag and/or done with Kahan's algorithm.
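For reference, Kahan's algorithm computes a difference of products such as $ac - bd$ with a fused multiply-add (FMA). This avoids losing precision to cancellation in the real part of $(a+bi)(c+di)$. The sketch below is plain Python on float64 scalars, not JAX. It uses `math.fma` when available (Python 3.13+) and otherwise emulates a single-rounding FMA with exact `fractions.Fraction` arithmetic, which is slow and only for illustration. This thread does not establish whether an FMA is reachable through JAX on Metal.

```python
import math
from fractions import Fraction


def fma(x: float, y: float, z: float) -> float:
    """x * y + z with a single rounding (emulated if math.fma is missing)."""
    if hasattr(math, "fma"):  # Python 3.13+
        return math.fma(x, y, z)
    # Fraction arithmetic is exact; float() rounds the exact result once.
    return float(Fraction(x) * Fraction(y) + Fraction(z))


def diff_of_products(a: float, b: float, c: float, d: float) -> float:
    """Kahan's algorithm for a*b - c*d."""
    w = c * d
    e = fma(-c, d, w)  # rounding error of w, i.e. w - c*d (exact barring underflow)
    f = fma(a, b, -w)  # a*b - w, rounded once
    return f + e


def complex_mul_kahan(a: float, b: float, c: float, d: float) -> tuple[float, float]:
    """(a + bi) * (c + di), both parts computed with Kahan's algorithm."""
    real = diff_of_products(a, c, b, d)   # a*c - b*d
    imag = diff_of_products(a, d, -b, c)  # a*d + b*c
    return real, imag


if __name__ == "__main__":
    a = c = 1.0 + 2.0**-30
    b = d = 1.0
    print("naive real:", a * c - b * d)
    print("kahan real:", complex_mul_kahan(a, b, c, d)[0])
```

With these inputs the exact real part is $2^{-29} + 2^{-60}$. The naive expression rounds $a \cdot c$ first and returns only $2^{-29}$, while the FMA-based version keeps the low-order term.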
There is a fork of this repo made for Apple MLX by @samefarrar:
https://github.com/xjdr-alt/entropix/issues/50#issuecomment-2401356490
I'll fix my bug in jax-metal later today. I hope to compare it with MLX in terms of RAM consumption, TPS and the order of magnitude of the numerical deviations. That should let us tell xjdr and the rest of the community which route to choose on Metal devices.
On 10 Oct 2024 at 12:29, Henk Poley wrote: "There is a fork of this repo made for Apple MLX" (https://github.com/xjdr-alt/entropix/issues/35#issuecomment-2404709870).
Big thank you to @samefarrar for porting it to mlx! It was by far the easiest setup (+extra points for uv)!
For newcomers: mlx has changed its dependencies, so you must use the version from this PR until it gets merged: https://github.com/samefarrar/entropix_mlx/pull/22
Running JAX on Apple silicon with the Metal backend (MPS) might need quite a bit of work.
Obviously, add `jax-metal` by running

```shell
poetry add jax-metal
```

which will add `jax-metal = "^0.1.0"` to `pyproject.toml`. Then the `jax.devices("gpu")` call needs to be removed. In `weights.py`, change line 31 to

```python
#device = jax.devices("gpu")[0]
```

and line 35 to

```python
w[name] = weight
```
Finally, you will be rewarded with the following error, which is where I am stuck at the moment:

```
entropix % poetry run python entropix/main.py
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728328226.631887 7453123 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1728328226.659179 7453123 service.cc:145] XLA service 0x11896d380 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728328226.659198 7453123 service.cc:153] StreamExecutor device (0): Metal,
I0000 00:00:1728328226.660714 7453123 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1728328226.660736 7453123 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Traceback (most recent call last):
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 108, in <module>
    tyro.cli(main)
  File "/Users/artuskg/Library/Caches/pypoetry/virtualenvs/entropix-H1BIt9k6-py3.12/lib/python3.12/site-packages/tyro/_cli.py", line 229, in cli
    return run_with_args_from_cli()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 105, in main
    generate(xfmr_weights, model_params, tokens)
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 75, in generate
    freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 51, in precompute_freqs_cis
    return jnp.exp(1j * freqs)
```