Open kvablack opened 8 months ago
Hi - I'm having trouble understanding your question. It sounds like you're comparing two implementations, but you've only shown us one implementation. Could you edit your question to show the code for both approaches?
@jakevdp sorry, I've edited my question to hopefully make things more clear. The only difference between the two implementations is the addition of the 8 lines indicated (the shard_map itself and the corresponding pmean).
Thanks! Assigning to @yashk2810, who might have some insights here.
Don't you need to `jnp.mean` for the jit version (without shard_map)?
@yashk2810 The `jnp.mean` happens inside the loss function (a scalar is returned).
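Roughly, the loss has this shape (a placeholder sketch, not the real transformer loss):

```python
import jax.numpy as jnp

def loss_fn(params, batch):
    # Stand-in for the model forward pass.
    preds = batch["x"] @ params["w"]
    # The mean over the batch happens here, so loss_fn already returns a scalar.
    return jnp.mean((preds - batch["y"]) ** 2)
```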
I don't see that loss function :)
Can you create a minimal reproducer that we can run?
Sure thing, here's my repro. With the shard_map version, I get 1.09 s/it, and with no shard_map, I get 2.95 s/it. This is on a v4-8 TPU VM.
Hey -- sorry for the late reply, can you try with the latest jax and jaxlib version? (or better try with nightly)
Also can you tell me what TPU you were using? TPUv4 but how many devices?
This was a v4-8 VM (smallest you can get, I think). I no longer have easy access to TPUs, but I replicated the issue with `jax[cuda12]==0.4.33` on an 8xH100 DGX machine. With no shard_map, I get 1.10 s/it, and with shard_map, I get 1.70 it/s.
Description
I'm trying to scale up some transformer training (currently at ~400M params), and as such I've been playing around with various ways to save memory and improve performance. On a whim, I tried replacing my `jax.jit(in_shardings=..., out_shardings=...)` setup for data parallelism with `jax.experimental.shard_map` (roughly as sketched below), and I immediately saw a 2.8x (!) speedup.

The reason this is a problem is that I would like to move on to more advanced parallelism techniques (tensor parallel, fully-sharded data parallel, etc.), but it seems like it would be prohibitively difficult to write these manually using `shard_map`. However, if I continue using pjit's automatic partitioning, I worry that I'm leaving a bunch of performance on the table. I would think the automatic partitioner would be able to produce code with more or less equal performance in this very simple case.
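Schematically, the two versions look like this (a simplified reconstruction with a placeholder model and SGD step, not my actual transformer code; the `"dp"` axis name, shapes, and learning rate are stand-ins):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all devices for data parallelism.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

def loss_fn(params, batch):
    # Placeholder model; the real thing is a ~400M-param transformer.
    preds = batch["x"] @ params["w"]
    # The mean over the batch happens here, so a scalar is returned.
    return jnp.mean((preds - batch["y"]) ** 2)

def sgd_update(params, grads, lr=1e-3):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Version 1: jit with in_shardings/out_shardings (automatic partitioning).
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("dp"))

def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return sgd_update(params, grads), loss

train_step_jit = jax.jit(
    train_step,
    in_shardings=(replicated, batch_sharded),
    out_shardings=(replicated, replicated),
)

# Version 2: the same step wrapped in shard_map (manual SPMD).
# The only additions are the shard_map wrapper and the pmeans over "dp".
def train_step_manual(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name="dp")
    loss = jax.lax.pmean(loss, axis_name="dp")
    return sgd_update(params, grads), loss

train_step_shmap = jax.jit(
    shard_map(
        train_step_manual,
        mesh=mesh,
        in_specs=(P(), P("dp")),  # params replicated, batch split along "dp"
        out_specs=(P(), P()),     # updated params and loss replicated
    )
)
```

The only semantic difference between the two is the `shard_map` wrapper and the corresponding `pmean`s; in the jit version, the compiler inserts the cross-device reduction for the `jnp.mean` inside the loss automatically.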
Here are the debugging steps I've tried so far:

- I used `jax.debug.inspect_array_sharding` to look at the sharding of intermediate activations, and they all looked correct (fully sharded along the DP axis); a minimal sketch of this check is included below.
- I used a `custom_vjp` to look at shardings during the backward pass, and they also looked correct.

I've attached the HLO below. I would really appreciate any guidance on this, thanks!
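For reference, the sharding check used `jax.debug.inspect_array_sharding` along these lines (a toy sketch, not the actual train step):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    h = jnp.tanh(x)  # stand-in for an intermediate activation
    # Print the sharding the partitioner ends up choosing for this value.
    jax.debug.inspect_array_sharding(h, callback=print)
    return h.sum()
```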
no-shard-map.txt shard-map.txt
What jax/jaxlib version are you using?
0.4.23
Which accelerator(s) are you using?
TPUv4
Additional system info?
Python 3.10.12, tpu-vm-v4-base
NVIDIA GPU info
No response