pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Test Script Not Utilizing All XLA Devices #7843

Open radna0 opened 1 month ago

radna0 commented 1 month ago

❓ Questions and Help

I'm running the official script here, but I only see two XLA devices being used, xla:0 and xla:1. I read that on a TPU v2-8 or v3-8 instance you get four chips (eight cores) for one worker, so does that mean I cannot fully use the memory of all devices? Given the code below, there should be around 128 GB of total memory?

import torch
import torch_xla
import torch_xla.core.xla_model as xm

devices = xm.get_xla_supported_devices()
print(f"Devices: {devices}")

total_used = 0.0
total_limit = 0.0
for device in devices:
    # Per-device memory limit reported by the runtime, in GB.
    mem = round(xm.get_memory_info(device)["bytes_limit"] / 1e9, 2)
    # Allocate a random 5-D tensor (b, c, f, h, w) with a random batch size.
    b = torch.randint(1, 8, (1,)).item()
    t = torch.randn(b, 4, 144, 720, 1280).to(device)
    mem_used = round(xm.get_memory_info(device)["bytes_used"] / 1e9, 2)
    total_used += mem_used
    total_limit += mem
    print(f"Total TPU device: {device} memory: {mem_used} / {mem} GB")
    xm.mark_step()  # flush pending work for this device

print(f"Total TPU memory: {total_used} / {total_limit} GB")
Devices: ['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']
Total TPU device: xla:0 memory: 2.12 / 16.62 GB
Total TPU device: xla:1 memory: 2.12 / 16.62 GB
Total TPU device: xla:2 memory: 2.12 / 16.62 GB
Total TPU device: xla:3 memory: 4.25 / 16.62 GB
Total TPU device: xla:4 memory: 8.49 / 16.62 GB
Total TPU device: xla:5 memory: 12.74 / 16.62 GB
Total TPU device: xla:6 memory: 6.37 / 16.62 GB
Total TPU device: xla:7 memory: 8.49 / 16.62 GB

I want to use all of this memory capacity with SPMD, for example to shard and train a video diffusion transformer model, which requires sharding a 5-dimensional tensor (b, c, f, h, w). Given that SPMD abstracts away the complexity of programming multiple devices and presents them as one big device, it seems perfect for the job? A minimal sketch of what I have in mind is below.
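Here is a rough sketch of what I am trying to do, assuming the SPMD API described in the r2.4 docs (xr.use_spmd, spmd.Mesh, spmd.mark_sharding); the mesh axis name and the choice to shard only the batch dimension are just illustrative:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
# 1-D mesh over all devices; the axis name 'data' is arbitrary.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# 5-D video tensor (b, c, f, h, w) placed on the single "virtual" SPMD device.
video = torch.randn(8, 4, 16, 720, 1280).to(xm.xla_device())
# Shard the batch dimension across the mesh; the other dims stay replicated.
xs.mark_sharding(video, mesh, ('data', None, None, None, None))
xm.mark_step()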

JackCaoG commented 1 month ago

Take a look at the SPMD docs at https://pytorch.org/xla/release/r2.4/spmd.html.

If you want to do DDP with SPMD, try https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py

If you want to do FSDP with SPMD, try https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py
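For orientation, a minimal sketch of the FSDP-via-SPMD (FSDPv2) route along the lines of that example might look like the following; the import path and the mesh setup are assumptions based on the linked script, and the nn.Linear is only a stand-in for a real model:

import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
# Import path taken from the linked example script; treat it as an assumption.
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# 1-D mesh with the axis named 'fsdp', as in the example.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))

# Any nn.Module works here; a small stand-in instead of a diffusion transformer.
model = nn.Linear(1024, 1024).to(xm.xla_device())
model = FSDPv2(model, mesh=mesh)  # parameters sharded along the 'fsdp' axis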