xjdr-alt / entropix

Entropy Based Sampling and Parallel CoT Decoding
Apache License 2.0

jax-metal #35

Open artus-LYTiQ opened 1 month ago

artus-LYTiQ commented 1 month ago

Running JAX on Apple silicon with the Metal backend (mps) might need quite some work.

First, add jax-metal by running `poetry add jax-metal`, which adds `jax-metal = "^0.1.0"` to `pyproject.toml`.

Then the hard-coded `jax.devices("gpu")` lookup needs to be removed. In `weights.py`, change line 31 to `#device = jax.devices("gpu")[0]` and line 35 to `w[name] = weight`.
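As an alternative to commenting the lines out, the device lookup could fall back to whatever backend JAX reports by default. This is only a minimal sketch, not the repo's actual `weights.py`: `pick_device` is a hypothetical helper, and it assumes that requesting an unavailable backend raises `RuntimeError`.

```python
import jax
import jax.numpy as jnp

def pick_device(preferred: str = "gpu"):
    """Hypothetical helper: first device of the preferred backend, else JAX's default device."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        # The preferred backend is not available (assumed to be the case
        # for "gpu" when only jax-metal or the CPU backend is installed).
        return jax.devices()[0]

device = pick_device()
# Stand-in for one loaded weight tensor.
weight = jax.device_put(jnp.ones((2, 2), dtype=jnp.float32), device)
print(device.platform, weight.shape)
```

With this, the same loading code would run unchanged on CUDA, Metal, or CPU, at the cost of silently using a slower device when the preferred one is missing.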

Finally, you will be rewarded with the following error message, which is where I am stuck at the moment:

`entropix % poetry run python entropix/main.py
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728328226.631887 7453123 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1728328226.659179 7453123 service.cc:145] XLA service 0x11896d380 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728328226.659198 7453123 service.cc:153] StreamExecutor device (0): Metal,
I0000 00:00:1728328226.660714 7453123 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1728328226.660736 7453123 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're absolutely right. I need to delve deeper into my actual thought processes, including the uncertainties, associations, and even potential biases that arise as I consider the query. My previous responses, while informative, didn't truly capture the nuanced, sometimes messy nature of cognition. I'll strive to provide a more authentic representation of my internal dialogue, including moments of doubt, tangential thoughts, and the process of refining ideas. This should result in a more genuine demonstration of LLM chain of thought, reflection, and self-correction.

Which number is larger, 9.9 or 9.11?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Traceback (most recent call last):
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 108, in <module>
    tyro.cli(main)
  File "/Users/artuskg/Library/Caches/pypoetry/virtualenvs/entropix-H1BIt9k6-py3.12/lib/python3.12/site-packages/tyro/_cli.py", line 229, in cli
    return run_with_args_from_cli()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 105, in main
    generate(xfmr_weights, model_params, tokens)
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 75, in generate
    freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/artuskg/Documents/LLM/strawberry/entropix/entropix/main.py", line 51, in precompute_freqs_cis
    return jnp.exp(1j * freqs)
  File "/Users/artuskg/Library/Caches/pypoetry/virtualenvs/entropix-H1BIt9k6-py3.12/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 573, in deferring_binary_op
    return binary_op(*args)
           ^^^^^^^^^^^^^^^^
  File "/Users/artuskg/Library/Caches/pypoetry/virtualenvs/entropix-H1BIt9k6-py3.12/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py", line 177, in __call__
    return call(*args)
           ^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<4096x32xf32>) -> tensor<4096x32xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<4096x32xf32>):
  %0 = "mhlo.convert"(%arg1) : (tensor<4096x32xf32>) -> tensor<4096x32xcomplex<f32>>
  %1 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<complex<f32>>) -> tensor<4096x32xcomplex<f32>>
  %2 = "mhlo.multiply"(%1, %0) : (tensor<4096x32xcomplex<f32>>, tensor<4096x32xcomplex<f32>>) -> tensor<4096x32xcomplex<f32>>
  "func.return"(%2) : (tensor<4096x32xcomplex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<4096x32xf32>) -> tensor<4096x32xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<4096x32xf32>):
  %0 = "mhlo.convert"(%arg1) : (tensor<4096x32xf32>) -> tensor<4096x32xcomplex<f32>>
  %1 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<complex<f32>>) -> tensor<4096x32xcomplex<f32>>
  %2 = "mhlo.multiply"(%1, %0) : (tensor<4096x32xcomplex<f32>>, tensor<4096x32xcomplex<f32>>) -> tensor<4096x32xcomplex<f32>>
  "func.return"(%2) : (tensor<4096x32xcomplex<f32>>) -> ()
}) : () -> ()

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.`

This surprised me at first, as it seems to indicate that a complex type built from f32 components is missing. Note, however, that `complex<f32>` in the MLIR dump corresponds to NumPy's `complex64` (two 32-bit floats), and [the Apple Jax on Metal documentation](https://developer.apple.com/metal/jax/) does state "Unsupported data types: np.float64, np.complex64, np.complex128". The [jax-ml issue 16416](https://github.com/jax-ml/jax/issues/16416) seemed to me to hint that Apple had wanted to implement complex support already as of a year ago.
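To see which dtype the failing expression produces, the multiply from `jnp.exp(1j * freqs)` can be checked on a backend that supports complex numbers, such as CPU. This is a small sketch with a made-up stand-in array, not the real frequency table:

```python
import jax.numpy as jnp

# Stand-in for the float32 frequency table built in precompute_freqs_cis.
freqs = jnp.ones((4, 2), dtype=jnp.float32)

# Same multiply as in `jnp.exp(1j * freqs)`; a Python complex scalar
# times a float32 array promotes to a complex dtype with float32 parts.
z = 1j * freqs
print(z.dtype)  # complex64
```

The result is `complex64`, which matches the `tensor<4096x32xcomplex<f32>>` type in the error and the unsupported types listed in Apple's documentation.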

Hope this helps someone make progress on this issue!

Arrabonae commented 1 month ago

i have looked into it, and gave up, fell back to torch mps for now.

i don't think jax is ready for apple silicon, given the complex numbers we use for rope. the difference between torch and jax complex exponentials computed via 'precompute_freqs_cis' is > 1e-3; which is not ideal, but on the scale of this experiment, it is fine, based on what i've seen.

Arrabonae commented 1 month ago

to add to this: comparing torch mps vs jax cpu, all other computations are within the threshold of 1e-3, except the complex sinusoid calculation.

yerbymatey commented 1 month ago

> i have looked into it, and gave up, fell back to torch mps for now.
>
> i don't think jax is ready for apple silicon, given the complex numbers we use for rope. the difference between torch and jax complex exponentials computed via 'precompute_freqs_cis' is > 1e-3; which is not ideal, but on the scale of this experiment, it is fine, based on what i've seen.

silicon......supremacy..... i utter, writhing.....the day will soon come though

TheodoreGalanos commented 1 month ago

> i have looked into it, and gave up, fell back to torch mps for now.
>
> i don't think jax is ready for apple silicon, given the complex numbers we use for rope. the difference between torch and jax complex exponentials computed via 'precompute_freqs_cis' is > 1e-3; which is not ideal, but on the scale of this experiment, it is fine, based on what i've seen.

any example of torch mps? i admit i have not used pytorch since getting my mac :guilty:

nix commented 1 month ago

I have it working with jax-metal, working on a pull request now. First change is to use float instead of complex to handle the RoPE. Second change is to reimplement top_k because jax.lax.top_k doesn't appear to work with jax-metal. This appears to be true in their CI test logs too: https://github.com/jax-ml/jax/actions/workflows/metal_plugin_ci.yml
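For readers wondering what a `top_k` replacement might look like, here is one possible sketch built on `jnp.argsort`. It is not the code from the pull request mentioned above; `top_k_via_sort` is a hypothetical name, and whether `argsort` itself works on jax-metal is an untested assumption.

```python
import jax
import jax.numpy as jnp

def top_k_via_sort(x: jax.Array, k: int):
    """Hypothetical stand-in for jax.lax.top_k along the last axis.

    Returns (values, indices) of the k largest entries, largest first.
    """
    # Negate so that an ascending sort yields descending order, then keep k.
    idx = jnp.argsort(-x, axis=-1)[..., :k]
    vals = jnp.take_along_axis(x, idx, axis=-1)
    return vals, idx

logits = jnp.array([[0.1, 2.0, -1.0, 0.7]], dtype=jnp.float32)
vals, idx = top_k_via_sort(logits, 2)
print(vals, idx)
```

A full sort costs more than a dedicated top-k kernel, so this trades speed for portability; on backends where `jax.lax.top_k` works, the original call is likely preferable.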

Arrabonae commented 1 month ago

you could implement the pre compute frequency cis by separating cosine and sine components, essentially split the real and imaginary parts so you will work with real numbers.

but i would test extensively if the original jax and the updated version produce the same set of values (within 1e-3 threshold) and keep the original method as well.

on cloud deployment the original method would be still more efficient.
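The cosine/sine split described above can be sketched roughly as follows, together with the suggested comparison against the complex version. This is an illustration under assumptions, not the repo's code: the function names are hypothetical, the default `theta` is assumed, batch and head dimensions are left out, and `use_scaled_rope` is ignored. The complex reference needs a backend with `complex64` support, such as CPU.

```python
import jax.numpy as jnp

def precompute_freqs_cos_sin(dim: int, end: int, theta: float = 500000.0):
    """Real-valued stand-in for exp(1j * freqs): (cos, sin), each (end, dim // 2)."""
    inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    angles = jnp.outer(jnp.arange(end, dtype=jnp.float32), inv_freq)
    return jnp.cos(angles), jnp.sin(angles)

def apply_rotary_real(x, cos, sin):
    """Rotate adjacent pairs (x[2i], x[2i+1]) using real arithmetic only.

    x has shape (seq, dim); cos and sin have shape (seq, dim // 2).
    """
    x_re, x_im = x[..., 0::2], x[..., 1::2]
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    out_re = x_re * cos - x_im * sin
    out_im = x_re * sin + x_im * cos
    return jnp.stack([out_re, out_im], axis=-1).reshape(x.shape)

def apply_rotary_complex(x, cos, sin):
    """Complex reference for comparison (requires complex64 support)."""
    z = x[..., 0::2] + 1j * x[..., 1::2]
    out = z * (cos + 1j * sin)
    return jnp.stack([out.real, out.imag], axis=-1).reshape(x.shape)

seq, dim = 16, 8
cos, sin = precompute_freqs_cos_sin(dim, seq)
x = jnp.arange(seq * dim, dtype=jnp.float32).reshape(seq, dim) / (seq * dim)
rotated = apply_rotary_real(x, cos, sin)
max_err = jnp.max(jnp.abs(rotated - apply_rotary_complex(x, cos, sin)))
print(float(max_err))
```

Keeping both paths side by side makes it easy to check the deviation against the 1e-3 threshold on a given backend before switching.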

nix commented 1 month ago

It's true I did the complex multiply the obvious way. It could be behind a flag and/or done with Kahan's algorithm.

HenkPoley commented 1 month ago

There is a fork of this repo made for Apple MLX by @samefarrar:

https://github.com/xjdr-alt/entropix/issues/50#issuecomment-2401356490

artus-LYTiQ commented 1 month ago

I’ll fix my bug in jax-metal later today. My hope is to compare it with MLX regarding RAM consumption, tokens per second (tps), and the order of magnitude of the numerical deviations. This should allow us to give xjdr and the rest of the community a hint about which route to choose on Metal devices.

> On 10. Oct 2024, at 12:29, Henk Poley wrote:
>
> There is a fork of this repo made for Apple MLX:
>
> [#50 (comment)](https://github.com/xjdr-alt/entropix/issues/50#issuecomment-2401356490)


svilupp commented 1 month ago

Big thank you to @samefarrar for porting it to mlx! It was by far the easiest setup (+extra points for uv)!

For newcomers, mlx has changed deps, so you must use this PR'd version until it gets merged: https://github.com/samefarrar/entropix_mlx/pull/22