To further demonstrate my JAX ignorance, I tried adding a fully replicated sharding explicitly (since in this case everything should fit) with:
process_mesh = Mesh(
    np.array(jax.devices()).reshape((jax.process_count(), -1)),
    ("process", "device"),
)
shardings = [None for i in range(len(arr.shape))]
sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))

def _shard_fn(input_array):
    return jax.device_put(input_array, sharding)

out = jax.jit(_shard_fn)(arr)
but this doesn't seem to have any effect (I get the "non-addressable" error, or the "incompatible devices" error if I also specify out_shardings=sharding).
And, oddly enough, everything starts to work when I remove the jax.jit for the shard function and just do:
shardings = [None for i in range(len(arr.shape))]
sharding = NamedSharding(process_mesh, PartitionSpec(*shardings))
input_array = jax.device_put(arr, sharding)
return np.array(input_array)
Are there any downsides to doing this for this type of copy operation?
So, at small scale, that first error is crazy because the devices are literally the same, just in a different order. This seems like a bug? Could it be (another) issue with best_effort_sharding?
>>> a = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
>>> b = [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
>>> sorted(a) == sorted(b)
True
2) I'm still wrestling with the right thing to do here, and I should be more systematic.
The constraints, as I understand them, are:
1) Only device_put can move things between device types (CPU -> TPU).
2) Only jit can coordinate cross-host data movement across devices (I think), so arrays must be fully addressable to do it without jit.
3) device_put cannot be used inside jit, and device_put cannot be differentiated.
4) with_sharding_constraint can be used either inside or outside jit.
5) "device_put transposes while with_sharding_constraint doesn't" is what the JAX people have told me, which is apparently just because device_put can take a src.
I don't think my current solution is the right one, but I'm not entirely sure what it should be. I think it's:
1) cross-device: device_put
2) un-jitted wsc if the array has a sharding and is fully addressable
3) jitted wsc otherwise
which can probably simplify to
1) cross-device: device_put
2) jitted wsc otherwise

But I need to test it out.
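The proposed rule can be written down as a small decision table. This is a minimal sketch of the proposal exactly as stated (not a tested recipe); the names `Transfer`, `pick_transfer`, and `pick_transfer_simplified` are hypothetical, and the functions only choose a strategy, they do not move any data.

```python
from enum import Enum


class Transfer(Enum):
    """Hypothetical labels for the three transfer strategies discussed above."""
    DEVICE_PUT = "device_put"
    EAGER_WSC = "un-jitted with_sharding_constraint"
    JITTED_WSC = "jitted with_sharding_constraint"


def pick_transfer(changes_device_kind: bool, has_sharding: bool,
                  fully_addressable: bool) -> Transfer:
    # Step 1: only device_put moves data between device kinds (e.g. CPU -> TPU).
    if changes_device_kind:
        return Transfer.DEVICE_PUT
    # Step 2: an already-sharded, fully addressable array can be resharded without jit.
    if has_sharding and fully_addressable:
        return Transfer.EAGER_WSC
    # Step 3: otherwise, rely on jit to coordinate (possibly cross-host) movement.
    return Transfer.JITTED_WSC


def pick_transfer_simplified(changes_device_kind: bool) -> Transfer:
    # The proposed two-step simplification.
    return Transfer.DEVICE_PUT if changes_device_kind else Transfer.JITTED_WSC


print(pick_transfer(False, True, False))
```

Note that this table only captures the proposal; it says nothing about whether jit will accept a change in device order.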
Yeah, I'm puzzled by the behavior. I think your analysis is correct: you'd expect to need a two-stage movement, but why doesn't your current version work for this simple case? Even when I switched to the fully replicated sharding, JAX still reported the device issue. Isn't reshuffling exactly what .with_sharding_constraint is supposed to do?
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code """ Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions: """
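For reference, this is what the documented in-jit pattern looks like when nothing changes meshes: a minimal single-process sketch, assuming a one-axis mesh named "dp" over whatever devices are present (a single CPU device works). It does not reproduce the cross-mesh failure discussed here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over the locally visible devices; "dp" is just an illustrative axis name.
mesh = Mesh(np.array(jax.devices()), ("dp",))
replicated = NamedSharding(mesh, P())  # P() with no axes means fully replicated


@jax.jit
def constrain(x):
    # Inside a staged-out function, with_sharding_constraint plays the role
    # that device_put plays eagerly: it pins the sharding of an intermediate.
    return jax.lax.with_sharding_constraint(x * 2, replicated)


x = jax.device_put(jnp.arange(8.0), replicated)
y = constrain(x)
print(y.sharding)
```

Here the input and the constraint share the same mesh, so the device order never changes.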
IIUC, your original code should have worked fine: there were no input shardings, and out_shardings should be equivalent to jax.lax.with_sharding_constraint, which should force the sharding the way we want. The device_put should be identical (though it might not work if we needed cross-host movement in this case), but for some reason it works...
I synced with the Jax folks, and they said that jit requires the devices to be in the same order for input and output. They’re working on relaxing that but there are weird performance regressions.
The right thing is probably to detect whether the device sequence/mesh (not set) is changing and use device_put if so?
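The check being proposed is about order, not membership: the two device lists from the earlier REPL session contain the same devices but in a different sequence, which is what jit reportedly rejects. A minimal pure-Python sketch; `same_devices` and `needs_device_put` are hypothetical helpers, and with a NamedSharding the order could presumably be read as `[d.id for d in sharding.mesh.devices.flat]` (an assumption, not something the thread tested).

```python
from typing import Sequence


def same_devices(src: Sequence[int], dst: Sequence[int]) -> bool:
    # Same set of devices, ignoring order.
    return sorted(src) == sorted(dst)


def needs_device_put(src: Sequence[int], dst: Sequence[int]) -> bool:
    # Per the thread, jit wants input and output device sequences in the same order,
    # so any difference in sequence means falling back to device_put.
    return list(src) != list(dst)


a = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
b = [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
print(same_devices(a, b), needs_device_put(a, b))  # True True
```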
I feel like this is much more complicated than it needs to be.
Interesting... I guess you're expected to effectively re-use the same device mesh all the time, and that's why this isn't hitting people more often? This just seems confusingly hard for what I'd think would be such a common desire: "get the data to/from the CPU".
If you're forced to use the original mesh, it seems like you either have to give up and fully replicate the array, or somehow carefully choose a sharding that doesn't partition you across hosts, and then run device_put.
Are we holding it wrong? (I guess fully-replicating isn't too bad if you're doing it one array at a time?)
Yeah, so the best-effort sharding logic is relatively new. I added it, along with the crazy CPU-vs-accelerator logic, to handle loading 34B-param models on a v4-8 (in response to #508) and also to be able to load smaller models on our internal GPU cluster, which has less CPU memory, so we needed a solution there.
At the time I didn't realize that you were effectively stuck with one mesh; I just got lucky that the one mesh we were using at the time was the same one I created for best-effort sharding.
I think the things to do are:
WDYT?
Ah interesting, for the model loading side, everything makes sense: you're building the whole pytree, and if you don't shard, lots of models won't end up fitting. I don't quite follow how the CPU usage is reduced, since IIUC we're always loading the model replicated on the CPU and then sharding (but I didn't read the model loading code closely...). Oh, maybe because you have the implicit mesh you can avoid making a copy of the tree -> state_dict first... I should really just read the code and PR 😛.
Skip best-effort sharding if the context mesh is not set (easy, and I'm guessing it fixes this particular issue).
For export we implicitly have the mesh from the input array already, so we can re-use that. I'm probably confused, but I think generating an appropriate (non-replicated) sharding for an arbitrary mesh for the export side seems hard. We'd need to choose an arrangement that keeps physical devices within a single host so that we don't get the "non-addressable devices" issue when we convert to CPU, right?
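One way to see which arrangements stay host-local is plain index bookkeeping on a (process, device) mesh like `process_mesh` above: sharding only along "device" leaves every host holding all blocks, while sharding along "process" spreads blocks across hosts. A minimal sketch that ignores physical topology entirely; `blocks_held` and `each_host_has_full_copy` are hypothetical names, and the 4 x 4 mesh size is arbitrary.

```python
from itertools import product
from typing import Dict, Optional, Set, Tuple


def blocks_held(num_processes: int, local_devices: int,
                shard_axis: Optional[str]) -> Tuple[Dict[int, Set[int]], int]:
    """Map process index -> array blocks its local devices hold, plus the block count.

    The mesh is (process, device); the array's first dimension is split along
    `shard_axis` ("device", "process", or None for fully replicated).
    """
    num_blocks = {"device": local_devices, "process": num_processes, None: 1}[shard_axis]
    held: Dict[int, Set[int]] = {p: set() for p in range(num_processes)}
    for p, d in product(range(num_processes), range(local_devices)):
        # The block a device holds is its coordinate along the sharded mesh axis.
        block = {"device": d, "process": p, None: 0}[shard_axis]
        held[p].add(block)
    return held, num_blocks


def each_host_has_full_copy(num_processes: int, local_devices: int,
                            shard_axis: Optional[str]) -> bool:
    held, num_blocks = blocks_held(num_processes, local_devices, shard_axis)
    return all(blocks == set(range(num_blocks)) for blocks in held.values())


print(each_host_has_full_copy(4, 4, "device"), each_host_has_full_copy(4, 4, "process"))
```

Under this model, a spec like PartitionSpec("device") on that mesh is replicated across hosts, so each host could in principle assemble the whole array from its own devices.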
(It doesn't seem like there's anything wrong with the idea of best-effort sharding for loading, and we could even avoid looking for the implicit mesh, if just moving between meshes worked at all...)
I thought something dumb like this would work to make the array fully replicated and then copyable, but I still hit the "fetching an array that spans non-addressable devices" error. Though again, just switching to device_put works fine. I'm assuming this is just because I'm getting lucky and there's no model parallelism here. Maybe you just need both: first the sharding constraint to convert to replicated, and then the device_put to... I'm not sure what it's doing at this point, TBH.
shardings = [None for i in range(len(arr.shape))]
sharding = NamedSharding(arr.sharding.mesh, PartitionSpec(*shardings))

def _copy(in_array):
    return jax.lax.with_sharding_constraint(in_array, sharding)

arr = jax.jit(_copy, donate_argnums=0)(arr)
# but jax.device_put(arr, sharding) works fine!
return np.array(arr)
Probably ensure that we have the context mesh set inside this script at load time, because I added best-effort sharding specifically for the LoRA use case.
(I think) we're now getting the mesh correctly for loading, since the model loads quickly and without errors. It's only at save time that we run into this issue (there's a sort of copy of the best-effort sharding there). For this PEFT save logic, do we need the best-effort sharding at all? It seems like we're iterating over layers one at a time and copying them to the CPU, so we only need enough device memory for that single layer (again, I could be missing something).
Check inside named_jit that shardings are consistent with the context mesh (harder, but we can probably back it out from JAX?). For now, raise an error, but one could stretch to resharding beforehand.
Yeah, I feel some well-compartmentalized functions would help a lot: "make this CPU array appear on the devices with this sharding" and "make this device array appear on the CPU, replicated", both handling the resharding as necessary... It seems like they should be part of JAX, TBH... (Some of that could be extending hax.shard, it seems like, but for some of the state_dict manipulation you're outside of the PyTree context, so maybe it's harder to use there?)
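A minimal sketch of those two functions, assuming the fix the thread eventually converges on (reshard with jax.device_put rather than jit) and a single process; `to_devices` and `to_host` are hypothetical names, and multi-host behavior is not exercised here.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def to_devices(host_array: np.ndarray, sharding: NamedSharding) -> jax.Array:
    """Hypothetical helper: make a NumPy array appear on the devices with `sharding`."""
    return jax.device_put(host_array, sharding)


def to_host(arr: jax.Array) -> np.ndarray:
    """Hypothetical helper: make a device array appear on the CPU as one NumPy array.

    Reshards to fully replicated on the array's own mesh using device_put (not jit),
    so no other mesh's device order is involved.
    """
    replicated = NamedSharding(arr.sharding.mesh, P())
    return np.asarray(jax.device_put(arr, replicated))


# Usage: round-trip an array that is split along a one-axis "dp" mesh.
mesh = Mesh(np.array(jax.devices()), ("dp",))
n = len(jax.devices())
host = np.arange(n * 4, dtype=np.float32).reshape(n * 2, 2)
arr = to_devices(host, NamedSharding(mesh, P("dp")))
back = to_host(arr)
```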
Having named_jit do some magic to reshard seems okay, but probably unnecessary if it's relatively easy to coerce things ahead of time. Automatic resharding always worries me a bit: you might accidentally keep doing a bunch of data movement on every step without realizing it (I know it's unlikely, but e.g. "oops, my output sharding for my weights is different from my input sharding").
I haven't read the entire thread, but to do this: "This just seems confusingly hard for what I'd think would be such a common desire: 'get the data to/from the CPU'", maybe just try putting your array in pinned_host memory? jax.device_put(x, NamedSharding(mesh, pspec, memory_kind='pinned_host'))
This should keep the sharding the same as the TPU one without having to mess around with devices, i.e. the mesh stays the same with TPU devices! You only change the memory kind of the sharding to point to the host.
Note you need to enable this config: jax.config.update('jax_enable_memories', True)
Talking to @yashk2810, it seems like device_put is more capable than I thought (in particular it seems like it can do cross-host transfers)
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax.experimental.mesh_utils as mesh_utils
D = 4096
mesh_devices = mesh_utils.create_device_mesh((len(jax.devices()),))
smart_mesh = Mesh(mesh_devices, ('dp',))
# mesh = Mesh(jax.devices(), ('dp'))
host_mesh = Mesh(np.array(jax.devices()).reshape(jax.process_count(), -1), ('host', 'device'))
z = jnp.full((D, D), jax.process_index())
print(z.sharding)
smart_sharding = NamedSharding(smart_mesh, P('dp'))
host_sharding = NamedSharding(host_mesh, P('device'))
z2 = jax.jit(lambda: jnp.zeros((D, D)), out_shardings=smart_sharding)()
print(z2.sharding)
z3 = jax.jit(lambda: jnp.zeros((D, D)), out_shardings=host_sharding)()
print(z3.sharding)
# this is no good:
# Traceback (most recent call last):
# File "/home/dlwh/test_device_put.py", line 30, in <module>
# z4 = jax.jit(lambda x: x, out_shardings=host_sharding)(z2)
# ValueError: Received incompatible devices for jitted computation. Got argument x of <lambda> with shape float32[4096,4096] and device ids
# z4 = jax.jit(lambda x: x, out_shardings=host_sharding)(z2)
# print(z4.sharding)
z5 = jax.device_put(z2, host_sharding)
print(z5.sharding)
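The same idea can be checked on a single machine with two meshes that contain the same devices in different orders, which is the situation from the error messages above. A minimal single-process sketch, assuming fake CPU devices via the XLA flag `--xla_force_host_platform_device_count` (only effective if set before jax is imported; the code also works with one device). It does not exercise cross-host transfers.

```python
import os

# Optional: ask XLA for several fake CPU devices; must happen before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)
mesh_a = Mesh(np.array(devices), ("dp",))
mesh_b = Mesh(np.array(devices[::-1]), ("dp",))  # same devices, reversed order

x = jax.device_put(jnp.arange(n * 2.0), NamedSharding(mesh_a, P("dp")))
# device_put reshards directly onto the differently ordered mesh; no jit involved.
y = jax.device_put(x, NamedSharding(mesh_b, P("dp")))

order_a = [d.id for d in mesh_a.devices.flat]
order_b = [d.id for d in mesh_b.devices.flat]
print(order_a, order_b)
```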
ok @yashk2810 figured this out for me and I'm patching Haliax with the fix. https://github.com/stanford-crfm/haliax/pull/96 (Yash avert your eyes...)
@rjpower if you get a chance, could you check if things work with the latest jax nightly?
(@yashk2810 fixed it i think)
Hrm, I changed the Haliax dependency to "haliax @ git+https://github.com/stanford-crfm/haliax.git@main" and ran:
BASE_DIR=gs://wasabi-tpu-training/gsm8k/test/llama2-0 python infra/launch.py --foreground --tpu_name=tpu-0 -- python examples/gsm8k-lora/gsm8k_lora.py --config=examples/gsm8k-lora/gsm8k-llama2.yaml --hf_save_path=$BASE_DIR/hf --data_cache_dir=gs://wasabi-tpu-training/gsm8k/data --data_seed=0 --trainer.num_train_steps=10
I still see this error:
File "/opt/levanter/src/levanter/compat/torch_serialization.py", line 449, in <lambda>
model = jax.tree_util.tree_map(lambda arr: get_to_cpu(arr), model)
File "/opt/levanter/src/levanter/compat/torch_serialization.py", line 445, in get_to_cpu
out = jax.jit(_identity_fn, out_shardings=sharding)(arr)
ValueError: Received incompatible devices for jitted computation. Got argument x of _identity_fn with shape float32[32,8,4096] and device ids [0, 4, 8, 12, 16, 20, 24, 28, 1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31] on platform TPU and explicit output sharding with device ids [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15, 16, 17, 20, 21, 18, 19, 22, 23, 24, 25, 28, 29, 26, 27, 30, 31] on platform TPU
Guessing without looking, do we also need to adjust torch_serialization.py in Levanter?
.config:
env:
  XLA_FLAGS: "--xla_dump_to=/tmp/output_folder/xla_dumps --xla_dump_hlo_pass_re=.*"
  LIBTPU_INIT_ARGS: --xla_tpu_impure_oom_fast_exit_threshold=-1
docker_repository: levanter
zone: us-west4-a
tpu_type: v5litepod-32
vm_image: "tpu-ubuntu2204-base"
capacity_type: preemptible
autodelete: false
subnetwork: "default"
You need to use jax.device_put, not jax.jit
Thanks Yash, with this change to torch_serialization.py, things work for me:
- out = jax.jit(_identity_fn, out_shardings=sharding)(arr)
+ out = jax.device_put(arr, sharding)
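A minimal sketch of what the patched helper does when mapped over a model pytree, as in the traceback's `jax.tree_util.tree_map(lambda arr: get_to_cpu(arr), model)`. The thread does not show which `sharding` torch_serialization.py actually builds, so replicating over the array's own mesh is an assumption here, as are the skipping of non-array leaves and the toy `model` dict.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_to_cpu(arr):
    """Sketch of the patched helper: reshard with device_put, then copy to NumPy."""
    if not isinstance(arr, jax.Array):
        return arr  # leave non-array leaves (ints, strings, ...) untouched
    if isinstance(arr.sharding, NamedSharding):
        # Assumption: fully replicate over the array's own mesh.
        sharding = NamedSharding(arr.sharding.mesh, P())
    else:
        sharding = arr.sharding  # e.g. a single-device array: nothing to reshard
    out = jax.device_put(arr, sharding)  # was: jax.jit(_identity_fn, out_shardings=sharding)(arr)
    return np.asarray(out)


# Usage on a toy pytree with one sharded leaf, one unsharded leaf, and a Python int.
mesh = Mesh(np.array(jax.devices()), ("dp",))
n = len(jax.devices())
model = {
    "w": jax.device_put(jnp.ones((n * 2, 3)), NamedSharding(mesh, P("dp"))),
    "b": jnp.zeros(3),
    "step": 7,
}
cpu_model = jax.tree_util.tree_map(get_to_cpu, model)
```

The only substantive change from the failing version is that device_put, unlike jit with out_shardings, does not require the input and output device orders to match.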
@dlwh I'll send the one-line patch CL; let me know if I'm missing something obvious though!
lgtm!
622 fixes issue #609 for loading HF models into a sharded representation. But now when I try to serialize a PEFT model I'm getting a similar error as before:
I tried the simple thing of removing the sharding annotation entirely and just running
np.array(input)
but that yields the (I guess expected) error, since I'm assuming the original sharding is spread across multiple machines:
I haven't yet reproduced this as a test, but you can reproduce it using the gsm8k example:
The script needs one patch to load models correctly (there's an error now if you try to load a model using a name the way the script used to). I'll clean up the fixes and send them as a separate PR, but for now:
https://github.com/stanford-crfm/levanter/compare/main...rjpower:levanter:multi-lora?expand=1