pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

OOM due to RNG states w/SPMD #7813

Closed alexanderswerdlow closed 1 month ago

alexanderswerdlow commented 1 month ago

🐛 Bug

When using HF accelerate w/SPMD, I get OOM due to the rng states being duplicated and not sharded. My batch size is 512 in this example. I tried multiple mesh setups:

axis_names = ('fsdp', 'model')
xs.set_global_mesh(xs.Mesh(np.array(range(num_global_devices)), (16, 1), axis_names))

# Also occurs with
xs.set_global_mesh(xs.Mesh(np.array(range(num_global_devices)), (4, 4), axis_names))

I am wrapping my model with FSDPv2 and sharding my output as:

    def shard_output(output, mesh):
        real_output = None
        if isinstance(output, torch.Tensor):
            real_output = output
        elif isinstance(output, tuple):
            real_output = output[0]
        if real_output is None:
            raise ValueError("Something went wrong, the output of the model shouldn't be `None`")

        xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

And I specify input_sharding to the dataloader:

input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ('fsdp', None))

Error:

2024-08-07 00:36:47.191430: I torch_xla/csrc/device.cpp:81] Using SPMD virtual device optimization
Steps:   0%|                                                                                  | 1/1000000000 [00:02<706305:02:22,  2.54s/it]2024-08-07 00:36:47.916654: I torch_xla/csrc/xla_graph_executor.cpp:422] 725 live tensors: devices=()
2024-08-07 00:36:47.916681: I torch_xla/csrc/xla_graph_executor.cpp:400] Trying to sync the value of 725 tensor(s)
2024-08-07 00:36:47.917286: I torch_xla/csrc/xla_graph_executor.cpp:714] Tensors graph hash d2f66a8798ab062f97f99f7e13a21b on device SPMD:0
2024-08-07 00:36:47.927172: I torch_xla/csrc/xla_graph_executor.cpp:1515] Parameter sequence graph hash df3ea0926f078250dadb9eade3852d1b
2024-08-07 00:36:48.092352: I torch_xla/csrc/xla_graph_executor.cpp:1419] Compiling IR graph hash df3ea0926f078250dadb9eade3852d1b on device SPMD:0 ...
2024-08-07 00:36:48.092438: I torch_xla/csrc/runtime/pjrt_computation_client.cc:564] Auto SPMD partitioning disabled.
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 42.40G of 31.75G hbm. Exceeded hbm capacity by 10.66G.

Total hbm usage >= 42.66G:
    reserved        259.00M
    program          42.40G
    arguments            0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 42.40G:
    global           576.5K
    HLO temp         42.40G (100.0% utilization: Unpadded (41.93G) Padded (41.93G), 1.1% fragmentation (478.37M))

  Largest program allocations in hbm:

  1. Size: 1.11G
     Shape: bf16[32,384,48385]{1,2,0:T(8,128)(2,1)}
     Unpadded size: 1.11G
     Extra memory due to padding: 168.0K (1.0x expansion)
     XLA label: fusion.4130 = fusion(fusion.4131, fusion.4140, copy-done.1383), kind=kLoop, calls=fused_computation.3400
     Allocation type: HLO temp
     ==========================

  2. Size: 1.11G
     Shape: bf16[32,384,48385]{1,2,0:T(8,128)(2,1)}
     Unpadded size: 1.11G
     Extra memory due to padding: 168.0K (1.0x expansion)
     XLA label: copy.3927 = copy(fusion.4130)
     Allocation type: HLO temp
     ==========================

  3. Size: 384.00M
     Shape: u8[512,384,2048]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 384.00M
     XLA label: rng-bit-generator.52 = rng-bit-generator(custom-call.75), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  4. Size: 384.00M
     Shape: u8[512,384,2048]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 384.00M
     XLA label: rng-bit-generator.72 = rng-bit-generator(custom-call.95), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  5. Size: 384.00M
     Shape: u8[512,384,2048]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 384.00M
     XLA label: rng-bit-generator.55 = rng-bit-generator(custom-call.78), algorithm=rng_default
     Allocation type: HLO temp
     ==========================
...

Expected behavior

An option to not store/sync this state, shard it if possible, or simply disable reproducibility to save memory.

Environment

JackCaoG commented 1 month ago

I am wondering how you concluded that this is due to the RNG state? The RNG state in pytorch/xla is a scalar (s64); see https://github.com/pytorch/xla/blob/1ed2626725c47fcf0979e0790a219bf3b9de4ebd/torch_xla/csrc/xla_graph_executor.cpp#L148-L151

Do you know what bf16[32,384,48385] corresponds to in your model code? cc @alanwaketan

alexanderswerdlow commented 1 month ago

Sorry for the confusion, the 1st/2nd are definitely my code [logits]. The 3rd/4th and many subsequent at least appear to be from the rng. It was weird to me as well (not very knowledgeable about RNGs but thought it would just be a single value state as you indicated).
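As a sanity check on the allocation report above, the listed sizes follow directly from shape times element width. A minimal sketch (the helper `tensor_bytes` is hypothetical, not part of torch_xla, and padding from the tiled layout is ignored):

```python
from math import prod

# Element widths in bytes for the two dtypes that appear in the report.
DTYPE_BYTES = {"bf16": 2, "u8": 1}

GIB = 1024 ** 3
MIB = 1024 ** 2


def tensor_bytes(dtype: str, shape: tuple) -> int:
    """Unpadded size in bytes of a dense tensor (hypothetical helper)."""
    return DTYPE_BYTES[dtype] * prod(shape)


# Shapes copied from allocations 1 and 3 in the report.
logits = tensor_bytes("bf16", (32, 384, 48385))
rng = tensor_bytes("u8", (512, 384, 2048))

print(f"logits: {logits / GIB:.2f} GiB")  # 1.11 GiB
print(f"rng:    {rng / MIB:.2f} MiB")     # 384.00 MiB
```

The leading dimension of the logits is 32 (512 / 16 devices), while the rng-bit-generator outputs keep the full 512, which is consistent with those buffers not being sharded along the batch axis.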

I was able to fit things in memory with some fixes to my sharding, namely it seems like I need to do:

xs.mark_sharding(real_output, mesh, (("fsdp", "model"), None, None))
and
input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (('fsdp', 'model'), None))

However, I imagine that the RNG tensors shown above are still present on each device, although I didn't verify this myself. I don't have anything in my user code that would obviously cause this, besides of course some random() calls. I tried setting/not setting the seeds manually to no effect. My main suspicion is that I have tensor random() calls outside of my wrapped model, but to me that still seems like weird behavior.

I can't find anything in my code that is u8[512,384,2048] (specifically the last dimension), and there seemed to be a lot of these generated which doesn't match anything I have. Could random() tensor operations be problematic outside of a wrapped model?
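One way to see why sharding dim 0 over both mesh axes matters: on a (4, 4) mesh, a partition spec of ('fsdp', None, None) splits the batch only 4 ways, whereas (('fsdp', 'model'), None, None) splits it 16 ways. The sketch below only does that arithmetic. `shard_shape` is a hypothetical helper, not a torch_xla API, and it assumes every sharded dimension divides evenly.

```python
def shard_shape(shape, mesh_shape, axis_names, partition_spec):
    """Per-device shard shape for a tensor under a partition spec.

    Each spec entry is None (not split along that dim), a mesh axis name,
    or a tuple of axis names (dim split over the product of those axes).
    Hypothetical helper for illustration only.
    """
    axis_size = dict(zip(axis_names, mesh_shape))
    out = []
    for dim, spec in zip(shape, partition_spec):
        if spec is None:
            factor = 1
        elif isinstance(spec, str):
            factor = axis_size[spec]
        else:
            factor = 1
            for name in spec:
                factor *= axis_size[name]
        if dim % factor:
            raise ValueError(f"dim {dim} is not divisible by {factor}")
        out.append(dim // factor)
    return tuple(out)


axis_names = ("fsdp", "model")
global_logits = (512, 384, 48385)  # batch 512, other dims from the report

on_16x1 = shard_shape(global_logits, (16, 1), axis_names, ("fsdp", None, None))
on_4x4 = shard_shape(global_logits, (4, 4), axis_names, ("fsdp", None, None))
on_4x4_both = shard_shape(
    global_logits, (4, 4), axis_names, (("fsdp", "model"), None, None))

print(on_16x1)      # (32, 384, 48385)
print(on_4x4)       # (128, 384, 48385)
print(on_4x4_both)  # (32, 384, 48385)
```

With the (4, 4) mesh and the single-axis spec, each device would hold four times as many logits rows as in the 16-way case.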

JackCaoG commented 1 month ago

oh you meant the rng-bit-generator. @alanwaketan @jonb377 I remember we hit this issue before and there is an XLA flag for this

alexanderswerdlow commented 1 month ago

Was it export LIBTPU_INIT_ARGS="--xla_tpu_spmd_rng_bit_generator_unsafe=1" by chance? I believe I was already running with that one, but if it's another one I'd love to know😅
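For reference, a minimal sketch of setting that flag from Python rather than the shell. It assumes the variable has to be in the environment before `torch_xla` (and therefore libtpu) is first imported in the process. `add_libtpu_flag` is a hypothetical helper that appends to any existing value instead of overwriting it.

```python
import os


def add_libtpu_flag(flag: str) -> str:
    """Append `flag` to LIBTPU_INIT_ARGS unless it is already present."""
    current = os.environ.get("LIBTPU_INIT_ARGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["LIBTPU_INIT_ARGS"] = " ".join(current)
    return os.environ["LIBTPU_INIT_ARGS"]


# Call this before `import torch_xla` so the runtime can pick the flag up.
add_libtpu_flag("--xla_tpu_spmd_rng_bit_generator_unsafe=1")
```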

JackCaoG commented 1 month ago

yea I believe that's the one. Hmm, if that's the case I don't have any other immediate solution besides reducing the batch size. You are using 512 on a v4-32 (which has 16 devices), so the per-device batch size is 32. I think even if you cut it in half, so that each device has a batch size of 16, the MFU should still be reasonable.

alexanderswerdlow commented 1 month ago

Ah got it—I've been getting ~35% MFU (both with my own model and with train_decoder_only_fsdp_v2) which is reasonable (although seems lower than a lot of the quoted numbers).

Also this is not the main issue, but does the sharding config above look right w/1 process per worker? I haven't been able to find many examples of multi-host training (besides the transformers fork) and wanted to make sure since it's difficult to verify what's going on w/SPMD.

JackCaoG commented 1 month ago

yea multi-host and multi-device are generally the same for TPU (as long as you don't use multiple pods). Doing FSDPv2 (sharding on a 1D mesh per layer) is usually the best for MFU. 35% seems a bit low; on v4 we expect ~50% for the decoder-only model. Maybe try https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention_fsdp_v2.py, which uses flash attention and should also help reduce memory usage.
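For context on the MFU figures, here is a rough sketch of how such a number is commonly estimated. It assumes the usual 6 x parameters x tokens approximation for the training FLOPs of a dense decoder (attention FLOPs ignored) and a TPU v4 peak of 275 bf16 TFLOP/s per chip. The parameter count and step time below are made-up inputs, not measurements from this thread; only the batch size (512), sequence length (384), and chip count (16) come from the discussion.

```python
def estimate_mfu(n_params, tokens_per_step, step_time_s, n_chips,
                 peak_flops_per_chip=275e12):
    """Model FLOPs utilization using the 6 * N * T training-FLOPs estimate."""
    model_flops = 6 * n_params * tokens_per_step
    achieved_flops_per_s = model_flops / step_time_s
    return achieved_flops_per_s / (n_chips * peak_flops_per_chip)


# Hypothetical 1B-parameter model and 0.5 s step time, for illustration only.
mfu = estimate_mfu(n_params=1e9, tokens_per_step=512 * 384,
                   step_time_s=0.5, n_chips=16)
print(f"MFU ~ {mfu:.1%}")
```

Because the estimate is linear in step time, halving the step time at a fixed batch size doubles the reported MFU.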

alexanderswerdlow commented 1 month ago

Got it. I should have mentioned that in my testing (with the files above), flash attention is ~15% slower than running without it. I only have access to TPU v4s, so I assume part of the reason is that the optimizations target v5. The reason I was asking is that when I used the same config for multi-host, I found it to be very slow until I replicated per host (which makes perfect sense).

JackCaoG commented 1 month ago

hmm this is weird. I can try to take a profile myself and check after I am back next week. If you already do gradient checkpointing, then I think that's most of the optimization we applied. On v4 we were also able to get 50% MFU for both llama2 and llama3. @alanwaketan do you have the latest repro for llama that we can let users run?

Oh, also I remember on v4 there is a config to enable the async all-gather, do you remember what the flag is?

alanwaketan commented 1 month ago

All the flags are set by default for v4.

alexanderswerdlow commented 1 month ago

@JackCaoG, I wanted to follow up on this since I have been getting what seem to be related errors. When I set the batch size to a reasonable value (far below what works on a single device), I get a hang and then the following OOM several steps (~10) in:


Output size 24.00M; shares 0B with arguments.

Program hbm requirement 31.66G:
    global           68.03M
    scoped            4.75M
    HLO temp         31.21G (100.0% utilization: Unpadded (28.96G) Padded (28.96G), 7.2% fragmentation (2.25G))
    overlays        379.70M

  Largest program allocations in hbm:

  1. Size: 379.70M
     XLA label: overlays
     Allocation type: overlays
     ==========================

  2. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.530 = rng-bit-generator(custom-call.570), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  3. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.547 = rng-bit-generator(custom-call.587), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  4. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.502 = rng-bit-generator(custom-call.542), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  5. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.485 = rng-bit-generator(custom-call.523), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  6. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.481 = rng-bit-generator(custom-call.519), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  7. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.503 = rng-bit-generator(custom-call.543), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  8. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.504 = rng-bit-generator(custom-call.544), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  9. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.487 = rng-bit-generator(custom-call.525), algorithm=rng_default
     Allocation type: HLO temp
     ==========================

  10. Size: 96.00M
     Shape: u8[256,384,1024]{2,1,0:T(8,128)(4,1)}
     Unpadded size: 96.00M
     XLA label: rng-bit-generator.479 = rng-bit-generator(custom-call.517), algorithm=rng_default
     Allocation type: HLO temp
     ==========================
...[Truncated for brevity]

And finally if I set bs=1, my code similarly hangs several steps in but with a slightly different error:

F0811 00:44:43.435031 1034609 hlo_instruction.h:1877] Check failed: has_sharding()
*** Check failure stack trace: ***
    @     0x7f45614bc6e4  (unknown)
    @     0x7f45614bc204  (unknown)
    @     0x7f45614bca49  (unknown)
    @     0x7f455a6d2b4e  (unknown)
    @     0x7f455a6ddbc5  (unknown)
    @     0x7f455a6c5dbd  (unknown)
    @     0x7f455a6c5819  (unknown)
    @     0x7f455a6c5dbd  (unknown)
    @     0x7f455a6c5819  (unknown)
    @     0x7f45585b93cf  (unknown)
    @     0x7f45585b145d  (unknown)
    @     0x7f4559b78a93  (unknown)
    @     0x7f45610c5b1b  (unknown)
    @     0x7f45610cc104  (unknown)
    @     0x7f45610d4da5  (unknown)
    @     0x7f456138f883  (unknown)
    @     0x7f47fa894ac3  (unknown)
https://symbolize.stripped_domain/r/?trace=7f45614bc6e4,7f45614bc203,7f45614bca48,7f455a6d2b4d,7f455a6ddbc4,7f455a6c5dbc,7f455a6c5818,7f455a6c5dbc,7f455a6c5818,7f45585b93ce,7f45585b145c,7f4559b78a92,7f45610c5b1a,7f45610cc103,7f45610d4da4,7f456138f882,7f47fa894ac2&map=
https://symbolize.stripped_domain/r/?trace=7f47fa8969fc,7f47fa84251f&map=
*** SIGABRT received by PID 1033766 (TID 1034609) on cpu 204 from PID 1033766; ***
E0811 00:44:43.441233 1034609 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0811 00:44:43.441243 1034609 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0811 00:44:43.441253 1034609 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0811 00:44:43.441258 1034609 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0811 00:44:43.441277 1034609 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0811 00:44:43.441283 1034609 coredump_hook.cc:472] RAW: Dumping core locally.
F0811 00:44:43.435031 1034609 hlo_instruction.h:1877] Check failed: has_sharding()
E0811 00:44:43.913293 1034609 process_state.cc:805] RAW: Raising signal 6 with default behavio

I tried setting PT_XLA_DEBUG, PT_XLA_DEBUG_LEVEL, XLA_IR_DEBUG, XLA_HLO_DEBUG but nothing obvious is noted. At first I thought the error was due to the parallel device loader but it doesn't seem this is the case. My program is very simple and no logging occurs before this happens, just forward/backward of a wrapped model.

JackCaoG commented 1 month ago

Is it a decoder-only model? If so, can you give me an example config (https://github.com/pytorch/xla/blob/8f7948848570800d4eb2a2c2a0217f80e784496f/examples/decoder_only_model.py#L12-L19) and the batch size you are using? I can try to repro it on my end.

alexanderswerdlow commented 1 month ago

@JackCaoG It is, I will try to repro it using the built-in decoder model. I've found that it only occurs when training on multiple workers (e.g., it fails on a TPU v4-32 but succeeds on a TPU v4-16) at larger batch sizes, although notably there is an OOM on a v4-32 slice where the same batch size is handled by a single v4-8 slice. Is there perhaps a complete sharding example with data parallelism over all devices and parameter sharding per node (e.g., Hybrid Zero2 in DeepSpeed/PyTorch terminology) somewhere that I can use to make sure it's not an issue on my side?

JackCaoG commented 1 month ago

@bhavya01 what's the command/repo you used to benchmark the llama3 on single and multi host?

bhavya01 commented 1 month ago

I used https://github.com/pytorch-tpu/transformers/blob/alanwaketan/flash_attention/examples/pytorch/language-modeling/run_clm.py to benchmark llama3

This script shards only on the batch size

git clone -b alanwaketan/flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

export JAX_PLATFORMS=tpu,cpu
export PJRT_DEVICE=TPU
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export PROFILE_EPOCH=0
export PROFILE_STEP=3
export PROFILE_DURATION_MS=40000
export PROFILE_LOGDIR=/home/bbahl/
export XLA_USE_SPMD=1
python3 examples/pytorch/language-modeling/run_clm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 8 \
  --do_train \
  --output_dir /home/bbahl/tmp/test-clm \
  --overwrite_output_dir \
  --config_name /home/bbahl/config-8B.json \
  --cache_dir /home/bbahl/cache \
  --tokenizer_name meta-llama/Meta-Llama-3-8B \
  --block_size 8192 \
  --optim adafactor \
  --save_strategy no \
  --logging_strategy no \
  --fsdp "full_shard" \
  --fsdp_config /home/bbahl/fsdp_config.json \
  --torch_dtype bfloat16 \
  --dataloader_drop_last yes \
  --flash_attention \
  --max_steps 10

config-8B.json

{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": false,
  "vocab_size": 128256
}

fsdp_config.json

{
    "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
    ],
    "xla": true,
    "xla_fsdp_v2": true,
    "xla_fsdp_grad_ckpt": true
}

alexanderswerdlow commented 1 month ago

@bhavya01 @JackCaoG Thanks! I was able to reproduce it with the profiler showing 50% MFU with no extra flags and I'll close this issue.

In my case, I found I needed to do the following to get Hybrid-Zero2, which was necessary for a model that's too big for DDP but too small for full FSDP to make sense (it seems too much time is spent on communication).

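# Imports assumed from context; they are not shown in the original comment.
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla._internal import tpu

xr.use_spmd()
num_global_devices = xr.global_runtime_device_count()
num_local_devices = tpu.num_available_devices()  # devices on this host
num_nodes = num_global_devices // num_local_devices
# 3D mesh: replicate across hosts ('replica'), shard within a host ('fsdp').
spmd_mesh = xs.Mesh(np.array(range(num_global_devices)), (num_nodes, num_local_devices, 1), ('replica', 'fsdp', 'model'))
xs.set_global_mesh(spmd_mesh)
...
sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), (('replica', 'fsdp'), None)) # Pass to MpDeviceLoader
xs.mark_sharding(output, spmd_mesh, (('replica', 'fsdp'), None, None)) # Mark output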
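To visualise what this mesh means on a v4-32, the device ids can be laid out in plain Python. This is only a sketch of the logical layout: it assumes 4 local devices per host, that device ids are numbered host by host, and that the mesh is the row-major reshape of the id list to (num_nodes, num_local_devices, 1). It does not call torch_xla.

```python
num_global_devices = 16  # v4-32, as discussed above
num_local_devices = 4    # assumption: 4 devices per host
num_nodes = num_global_devices // num_local_devices

device_ids = list(range(num_global_devices))

# Devices that share a 'replica' index form one 'fsdp' group (one host):
# parameters are sharded across the devices inside each group.
fsdp_groups = [
    device_ids[r * num_local_devices:(r + 1) * num_local_devices]
    for r in range(num_nodes)
]

# Devices that share an 'fsdp' index form one 'replica' group:
# each of them holds the same parameter shard, one copy per host.
replica_groups = [
    device_ids[c::num_local_devices] for c in range(num_local_devices)
]

# The batch dimension is split over ('replica', 'fsdp'), so over all devices.
batch_shards = num_nodes * num_local_devices

print(fsdp_groups)
print(replica_groups)
print(batch_shards)
```

Under these assumptions the all-gathers for parameters stay within a host, while only gradient reduction crosses hosts, which matches the motivation given above for avoiding full FSDP.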

I'll add a few obstacles I faced (hopefully helpful to others or could be clarified in docs):