I am wondering how you concluded that this is due to the RNG state? The RNG state in pytorch/xla is a scalar (s64); it lives in https://github.com/pytorch/xla/blob/1ed2626725c47fcf0979e0790a219bf3b9de4ebd/torch_xla/csrc/xla_graph_executor.cpp#L148-L151
Do you know what bf16[32,384,48385] corresponds to in your model code? cc @alanwaketan
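(For reference, that scalar seed can be inspected from Python; a minimal sketch, assuming the xm.get_rng_state/set_rng_state helpers:)
# Minimal sketch: the state torch_xla tracks per device is just this scalar seed.
import torch_xla.core.xla_model as xm

seed = xm.get_rng_state()  # integer backing the s64 seed referenced above
xm.set_rng_state(42)       # replace the seed for the current XLA device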
Sorry for the confusion: the 1st/2nd are definitely my code [logits]. The 3rd/4th and many subsequent ones at least appear to be from the RNG. It was weird to me as well (I'm not very knowledgeable about RNGs, but I thought it would just be a single-value state, as you indicated).
I was able to fit things in memory with some fixes to my sharding; namely, it seems like I need to do:
xs.mark_sharding(real_output, mesh, (("fsdp", "model"), None, None))
and
input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (('fsdp', 'model'), None))
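Roughly, these two specs assume a 2D ('fsdp', 'model') mesh over all devices (a minimal sketch of my setup; the exact sizes are illustrative):
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# 2D mesh over all devices; the 'model' axis is currently 1, so this is effectively pure FSDP.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))
xs.set_global_mesh(mesh)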
However, I imagine that the RNG tensors shown above are still present on each device, although I didn't verify this myself. I don't have anything in my user code that would obviously cause this, besides of course some random() calls. I tried setting/not setting the seeds manually to no effect. My main suspicion is that I have tensor random() calls outside of my wrapped model, but to me that still seems like weird behavior.
I can't find anything in my code that is u8[512,384,2048] (specifically the last dimension), and there seemed to be a lot of these generated, which doesn't match anything I have. Could random() tensor operations be problematic outside of a wrapped model?
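(By random() calls I mean things like the sketch below, executed on the XLA device outside the FSDPv2-wrapped module; I believe each of these lowers to an HLO rng-bit-generator over the full, unsharded shape, but I haven't confirmed the exact mapping. The shapes are just placeholders.)
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Noise/masks created outside the wrapped model; nothing marks these as sharded.
noise = torch.rand(512, 384, 2048, device=device)
mask = torch.rand(512, 384, device=device) > 0.5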
Oh, you meant the rng-bit-generator. @alanwaketan @jonb377 I remember we hit this issue before and there is an XLA flag for this.
Was it export LIBTPU_INIT_ARGS="--xla_tpu_spmd_rng_bit_generator_unsafe=1" by chance? I believe I was already running with that one, but if it's another one I'd love to know 😅
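(For completeness, I'm setting it before torch_xla initializes the TPU runtime, roughly like this; exporting it in the shell as above should be equivalent.)
import os

# LIBTPU_INIT_ARGS is read when the TPU runtime is initialized, so set it early,
# before importing torch_xla / touching any XLA device.
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_spmd_rng_bit_generator_unsafe=1"

import torch_xla  # noqa: E402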
Yeah, I believe that's the one. Hmm, if that's the case I don't have any other immediate solution besides reducing the batch size. You are using 512 on a v4-32 (a v4-32 has 16 devices), so the per-device batch size is 32. I think even if you cut it in half so each device has a batch size of 16, the MFU should still be reasonable.
Ah, got it. I've been getting ~35% MFU (both with my own model and with train_decoder_only_fsdp_v2), which is reasonable (although it seems lower than a lot of the quoted numbers).
Also, this is not the main issue, but does the sharding config above look right with 1 process per worker? I haven't been able to find many examples of multi-host training (besides the transformers fork) and wanted to make sure, since it's difficult to verify what's going on with SPMD.
Yeah, multi-host and multi-device are generally the same for TPU (as long as you don't use multiple pods). Doing FSDPv2 (sharding with a 1D mesh per layer) is usually the best for MFU. 35% seems a bit low; on v4 we expect ~50% for the decoder-only model. Maybe try https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention_fsdp_v2.py, which uses flash attention and should also help to reduce the memory usage.
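Roughly, the FSDPv2 setup in those examples looks like this (a sketch only; see the linked script for the real version, and the stand-in model here is just a placeholder):
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

xr.use_spmd()
# 1D mesh over all devices; FSDPv2 shards each wrapped layer's parameters along it.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
xs.set_global_mesh(mesh)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(xm.xla_device())  # stand-in for the decoder
model = FSDPv2(model, mesh=mesh)  # pass an auto_wrap_policy to shard per decoder layer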
Got it. I should have mentioned that in my testing (with the files above), flash attention is ~15% slower than without flash attention. I only have access to TPU v4s, so I'm assuming part of the reason is optimizations for v5. The reason I was asking is that when I used the same config for multi-host, I found it to be very slow until I replicated per host (which makes perfect sense).
Hmm, this is weird. I can try to take a profile myself and check after I am back next week. If you already do gradient checkpointing, then I think that's most of the optimization we applied. On v4 we were also able to get 50% MFU for both llama2 and llama3. @alanwaketan do you have the latest repro for llama that we can let the user run?
Oh, also I remember that on v4 there is a config to enable async all-gather; do you remember what the flag is?
All the flags are set by default for v4.
@JackCaoG, I wanted to follow up on this since I have been getting what seem to be related errors. When I set the batch size to a reasonable value (far below what works on a single device), I get a hang and then the following OOM several steps (~10) in:
Output size 24.00M; shares 0B with arguments.
Program hbm requirement 31.66G:
global 68.03M
scoped 4.75M
HLO temp 31.21G (100.0% utilization: Unpadded (28.96G) Padded (28.96G), 7.2% fragmentation (2.25G))
overlays 379.70M
Largest program allocations in hbm:
1. Size: 379.70M
XLA label: overlays
Allocation type: overlays
==========================
2. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.530 = rng-bit-generator(custom-call.570), algorithm=rng_default
Allocation type: HLO temp
==========================
3. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.547 = rng-bit-generator(custom-call.587), algorithm=rng_default
Allocation type: HLO temp
==========================
4. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.502 = rng-bit-generator(custom-call.542), algorithm=rng_default
Allocation type: HLO temp
==========================
5. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.485 = rng-bit-generator(custom-call.523), algorithm=rng_default
Allocation type: HLO temp
==========================
6. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.481 = rng-bit-generator(custom-call.519), algorithm=rng_default
Allocation type: HLO temp
==========================
7. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.503 = rng-bit-generator(custom-call.543), algorithm=rng_default
Allocation type: HLO temp
==========================
8. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.504 = rng-bit-generator(custom-call.544), algorithm=rng_default
Allocation type: HLO temp
==========================
9. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.487 = rng-bit-generator(custom-call.525), algorithm=rng_default
Allocation type: HLO temp
==========================
10. Size: 96.00M
Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
Unpadded size: 96.00M
XLA label: rng-bit-generator.479 = rng-bit-generator(custom-call.517), algorithm=rng_default
Allocation type: HLO temp
==========================
...[Truncated for brevity]
And finally, if I set bs=1, my code similarly hangs several steps in, but with a slightly different error:
F0811 00:44:43.435031 1034609 hlo_instruction.h:1877] Check failed: has_sharding()
*** Check failure stack trace: ***
@ 0x7f45614bc6e4 (unknown)
@ 0x7f45614bc204 (unknown)
@ 0x7f45614bca49 (unknown)
@ 0x7f455a6d2b4e (unknown)
@ 0x7f455a6ddbc5 (unknown)
@ 0x7f455a6c5dbd (unknown)
@ 0x7f455a6c5819 (unknown)
@ 0x7f455a6c5dbd (unknown)
@ 0x7f455a6c5819 (unknown)
@ 0x7f45585b93cf (unknown)
@ 0x7f45585b145d (unknown)
@ 0x7f4559b78a93 (unknown)
@ 0x7f45610c5b1b (unknown)
@ 0x7f45610cc104 (unknown)
@ 0x7f45610d4da5 (unknown)
@ 0x7f456138f883 (unknown)
@ 0x7f47fa894ac3 (unknown)
https://symbolize.stripped_domain/r/?trace=7f45614bc6e4,7f45614bc203,7f45614bca48,7f455a6d2b4d,7f455a6ddbc4,7f455a6c5dbc,7f455a6c5818,7f455a6c5dbc,7f455a6c5818,7f45585b93ce,7f45585b145c,7f4559b78a92,7f45610c5b1a,7f45610cc103,7f45610d4da4,7f456138f882,7f47fa894ac2&map=
https://symbolize.stripped_domain/r/?trace=7f47fa8969fc,7f47fa84251f&map=
*** SIGABRT received by PID 1033766 (TID 1034609) on cpu 204 from PID 1033766; ***
E0811 00:44:43.441233 1034609 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0811 00:44:43.441243 1034609 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0811 00:44:43.441253 1034609 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0811 00:44:43.441258 1034609 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0811 00:44:43.441277 1034609 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0811 00:44:43.441283 1034609 coredump_hook.cc:472] RAW: Dumping core locally.
F0811 00:44:43.435031 1034609 hlo_instruction.h:1877] Check failed: has_sharding()
E0811 00:44:43.913293 1034609 process_state.cc:805] RAW: Raising signal 6 with default behavio
I tried setting PT_XLA_DEBUG, PT_XLA_DEBUG_LEVEL, XLA_IR_DEBUG, and XLA_HLO_DEBUG, but nothing obvious is noted. At first I thought the error was due to the parallel device loader, but that doesn't seem to be the case. My program is very simple and no logging occurs before this happens, just the forward/backward of a wrapped model.
Is it a decoder-only model? If so, can you give me an example config (https://github.com/pytorch/xla/blob/8f7948848570800d4eb2a2c2a0217f80e784496f/examples/decoder_only_model.py#L12-L19) and the batch size you are using? I can try to repro it on my end.
@JackCaoG It is; I will try to repro it using the built-in decoder model. I've found that it only occurs when training on multiple workers (e.g., it fails on a TPU v4-32 but succeeds on a TPU v4-16) at larger batch sizes, although notably there is an OOM on a v4-32 slice for the same batch size that a single v4-8 slice handles fine. Is there perhaps a complete sharding example with data parallelism over all devices and parameter sharding per node (e.g., Hybrid Zero-2 in DeepSpeed/PyTorch terminology) somewhere that I can use to make sure it's not an issue on my side?
@bhavya01 what's the command/repo you used to benchmark llama3 on single host and multi-host?
I used https://github.com/pytorch-tpu/transformers/blob/alanwaketan/flash_attention/examples/pytorch/language-modeling/run_clm.py to benchmark llama3. This script shards only on the batch dimension.
git clone -b alanwaketan/flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate
export JAX_PLATFORMS=tpu,cpu
export PJRT_DEVICE=TPU
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=40000
export PROFILE_LOGDIR=/home/bbahl/
export XLA_USE_SPMD=1
python3 examples/pytorch/language-modeling/run_clm.py \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 8 \
--do_train \
--output_dir /home/bbahl/tmp/test-clm \
--overwrite_output_dir \
--config_name /home/bbahl/config-8B.json \
--cache_dir /home/bbahl/cache \
--tokenizer_name meta-llama/Meta-Llama-3-8B \
--block_size 8192 \
--optim adafactor \
--save_strategy no \
--logging_strategy no \
--fsdp "full_shard" \
--fsdp_config /home/bbahl/fsdp_config.json \
--torch_dtype bfloat16 \
--dataloader_drop_last yes \
--flash_attention \
--max_steps 10
config-8B.json
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 8192,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0.dev0",
"use_cache": false,
"vocab_size": 128256
}
fsdp_config.json
{
"fsdp_transformer_layer_cls_to_wrap": [
"LlamaDecoderLayer"
],
"xla": true,
"xla_fsdp_v2": true,
"xla_fsdp_grad_ckpt": true
}
@bhavya01 @JackCaoG Thanks! I was able to reproduce it, with the profiler showing 50% MFU with no extra flags, so I'll close this issue.
I found I needed to do the following to get Hybrid Zero-2, which was necessary in my case with a model that's too big for DDP but too small for full FSDP to make sense (it seems there is too much time spent on communication).
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla._internal import tpu  # assuming this is where num_available_devices lives

xr.use_spmd()
num_global_devices = xr.global_runtime_device_count()
num_local_devices = tpu.num_available_devices()
num_nodes = num_global_devices // num_local_devices
# Replicate across hosts, shard parameters across the devices within each host.
spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), (num_nodes, num_local_devices, 1), ('replica', 'fsdp', 'model'))
xs.set_global_mesh(spmd_mesh)
...
sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), (('replica', 'fsdp'), None))  # Pass to MpDeviceLoader
xs.mark_sharding(output, spmd_mesh, (('replica', 'fsdp'), None, None))  # Mark output
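Continuing the snippet above, here's roughly how the two specs get hooked up (a sketch; model and train_loader are placeholders for my actual module/dataloader, and I believe FSDPv2 accepts a shard_output callback like this):
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2

def shard_output(output, mesh):
    # Replicate across hosts, shard the batch dim of the logits within each host.
    logits = output.logits if hasattr(output, 'logits') else output
    xs.mark_sharding(logits, mesh, (('replica', 'fsdp'), None, None))

model = FSDPv2(model, mesh=spmd_mesh, shard_output=shard_output)
loader = pl.MpDeviceLoader(train_loader, xm.xla_device(), input_sharding=sharding_spec)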
I'll add a few obstacles I faced (hopefully helpful to others, or things that could be clarified in the docs):
apply_xla_patch_to_nn_linear is necessary or helpful in some FSDP cases, but it's unclear exactly when. I found that using it caused some issues with mixed precision, and I had to add a lot of manual casts back to BF16 when it was enabled. In some quick testing, patching or not didn't make a big difference in my case.
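For reference, here's roughly how I was applying it (a sketch only; the import paths and the patched-forward argument are my best guess for this torch_xla version, and the stand-in model is just a placeholder):
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(xm.xla_device())  # stand-in model
# Swaps nn.Linear's forward for an einsum-based version so sharding propagates better;
# this is where I ran into the mixed-precision casts mentioned above.
model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)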
🐛 Bug
When using HF accelerate with SPMD, I get an OOM due to the RNG states being duplicated and not sharded. My batch size is 512 in this example. I tried multiple mesh setups:
I am wrapping my model with FSDPv2 and shard my output as:
And I specify input_sharding to the dataloader:
Error:
Expected behavior
An option to not store/sync this state, shard it if possible, or simply disable reproducibility to save memory.
Environment