pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Using DTensor with device meshes that use different devices for input and output #126795

Open rzambre opened 4 months ago

rzambre commented 4 months ago

🐛 Describe the bug

What is the expected physical placement of the local tensors of the gpu_replica_tensor_redist DTensor below? I would expect gpu_replica_tensor_redist to reside in GPU memory, just like gpu_replica_tensor_dist.

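```python
# torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, Replicate, distribute_tensor, init_device_mesh

# Two 1-D meshes over the same 4 ranks: one on CUDA devices, one on CPU.
gpu_device_mesh = init_device_mesh("cuda", (4,))
cpu_device_mesh = init_device_mesh("cpu", (4,))

rowwise_placement = [Shard(0)]
replica_placement = [Replicate()]

big_tensor = torch.randn(4, 4)  # created in CPU memory
# Shard(0) on the CPU mesh: each rank keeps one 1x4 row in CPU memory.
cpu_rowwise_tensor = distribute_tensor(big_tensor, device_mesh=cpu_device_mesh, placements=rowwise_placement)
# Replicate() on the GPU mesh: each rank holds the full 4x4 tensor on its GPU.
gpu_replica_tensor_dist = distribute_tensor(big_tensor, device_mesh=gpu_device_mesh, placements=replica_placement)

# Expected: a full 4x4 replica on each rank's GPU, like gpu_replica_tensor_dist.
gpu_replica_tensor_redist = cpu_rowwise_tensor.redistribute(gpu_device_mesh, replica_placement)

rank = dist.get_rank()
print(f"Rank {rank} (CPU rowwise):\n{cpu_rowwise_tensor}")
print(f"Rank {rank} (GPU replica):\n{gpu_replica_tensor_dist}")
print(f"Rank {rank} (GPU replica redistributed from CPU rowwise):\n{gpu_replica_tensor_redist}")
```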

When I run the code in the nvcr.io/nvidia/nemo:24.03.01.framework container and print the outputs, the local tensor of gpu_replica_tensor_redist shows no CUDA device (output snippet below). The local tensors of cpu_rowwise_tensor show no device either, and since PyTorch omits the device when printing CPU tensors, this suggests that gpu_replica_tensor_redist ends up in CPU memory instead of GPU memory.

```
root@batch-block4-1055:/workdir/torch-play# torchrun --nproc-per-node 4 --node-rank 0 ./redistribute_diff_mem_pools.py
Rank 1 (CPU rowwise):
DTensor(local_tensor=tensor([[ 0.2640, -0.4017,  1.0207,  1.5920]]), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Shard(dim=0),))
Rank 2 (CPU rowwise):
DTensor(local_tensor=tensor([[ 0.7826, -1.3320, -1.1844, -1.2736]]), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Shard(dim=0),))
Rank 0 (CPU rowwise):
DTensor(local_tensor=tensor([[ 0.3556, -0.4043, -0.2076, -1.3428]]), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Shard(dim=0),))
Rank 3 (CPU rowwise):
DTensor(local_tensor=tensor([[ 0.1196, -1.8618, -0.7965, -0.1710]]), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Shard(dim=0),))
…
…
Rank 1 (GPU replica):
DTensor(local_tensor=tensor([[ 0.3556, -0.4043, -0.2076, -1.3428],
        [ 0.2640, -0.4017,  1.0207,  1.5920],
        [ 0.7826, -1.3320, -1.1844, -1.2736],
        [ 0.1196, -1.8618, -0.7965, -0.1710]], device='cuda:1'), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Replicate(),))
…
…
Rank 1 (GPU replica redistributed from CPU rowwise):
DTensor(local_tensor=AsyncCollectiveTensor(tensor([[ 0.3556, -0.4043, -0.2076, -1.3428],
        [ 0.2640, -0.4017,  1.0207,  1.5920],
        [ 0.7826, -1.3320, -1.1844, -1.2736],
        [ 0.1196, -1.8618, -0.7965, -0.1710]])), device_mesh=DeviceMesh([0, 1, 2, 3]), placements=(Replicate(),))
```
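A more direct check than reading the printed form is to query the local tensor's device. The sketch below uses a hypothetical helper, `local_device_type`, on plain tensors; for a DTensor you would pass `dtensor.to_local()` (for example `gpu_replica_tensor_redist.to_local()`) to inspect the local shard. It also shows why the repr hint works: a CPU tensor's repr carries no `device=` field, while a CUDA tensor's does.

```python
import torch


def local_device_type(t: torch.Tensor) -> str:
    """Hypothetical helper: return the device type ("cpu", "cuda", ...) backing t.

    For a DTensor, pass dtensor.to_local() to inspect this rank's shard.
    """
    return t.device.type


cpu_t = torch.randn(1, 4)
# A CPU tensor's repr has no device= field, so the absence only implies CPU.
print(repr(cpu_t))
print(local_device_type(cpu_t))  # prints: cpu

if torch.cuda.is_available():
    gpu_t = cpu_t.to("cuda")
    # A CUDA tensor's repr includes device='cuda:N'.
    print(repr(gpu_t))
    print(local_device_type(gpu_t))  # prints: cuda
```

Checking `.device` explicitly avoids relying on repr formatting, which only shows the device when it differs from the default.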

@wanchaol is this a bug, or are DTensor APIs not meant to be used with different memory types?

Versions

```
root@batch-block4-2143:/workspace# python collect_env.py
Collecting environment information...
PyTorch version: 2.3.0a0+ebedce2
Is debug build: False
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: AMD EPYC 7J13 64-Core Processor

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] nvidia-pytriton==0.5.5
[pip3] onnx==1.15.0rc2
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.10.0
[pip3] pytorch-lightning==2.2.2
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.3.0a0+ebedce2
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.3.2
[pip3] torchsde==0.2.6
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[pip3] triton==2.2.0+e28a256
[pip3] tritonclient==2.44.0
[conda] Could not collect
```

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @msaroufim

wanchaol commented 4 months ago

@rzambre I think this is a bug; let me create a fix.

rzambre commented 4 months ago

Thank you for confirming @wanchaol

wanchaol commented 4 months ago

@rzambre My fix would basically be: if cpu_rowwise_tensor is being redistributed to a GPU device mesh, we'll first move the DTensor to the GPU device and then redistribute. The question I have is:

What behavior do you expect for the backward of redistribute? I think the backward would naturally go through the to_copy backward, moving the gradients from the GPU device back to the CPU device, just as cpu_tensor.to("cuda").sum().backward() does for a CPU tensor with requires_grad=True. I'd like to check whether that matches your expectation.

rzambre commented 4 months ago

@wanchaol could you please clarify what you mean by "backward of redistribute"? I may be missing some context for the use case you are referring to.
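For context on the backward question, the single-device analogue mentioned above can be run directly. The sketch below uses plain PyTorch with no DTensor: it copies a CPU leaf tensor that requires grad to the GPU, reduces it, and calls backward. The gradient is delivered on CPU, the leaf's own device. It falls back to CPU when no GPU is available, in which case `.to()` is a no-op and the check is trivial.

```python
import torch

# A CPU leaf tensor that participates in autograd.
cpu_tensor = torch.randn(4, 4, requires_grad=True)

# Use the GPU when present; on a CPU-only machine .to("cpu") returns the same tensor.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Forward: device copy, then reduction. On a GPU machine autograd records the copy,
# and its backward moves the incoming gradient back to the source device.
cpu_tensor.to(device).sum().backward()

# The gradient lives on the leaf's device (CPU) and is all ones, since d(sum)/dx = 1.
print(cpu_tensor.grad.device)  # prints: cpu
print(cpu_tensor.grad)
```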

rzambre commented 4 months ago

> My fix would basically be: if the cpu_rowwise_tensor is redistributing to a GPU device mesh, we'll move the DTensor to a GPU device first, then do redistribute.

@wanchaol ideally, we'd want redistribute to translate into a single AllGather whose input and output buffers are on different device types. Some backends, such as NCCL, do support source and destination buffers of different memory types as long as both are accessible from the GPU. Does the ProcessGroup API allow src and dst buffers of different device types? If it does, we could avoid the manual copy.
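For reference, redistributing `[Shard(0)]` to `[Replicate()]` on a 1-D mesh amounts to an all-gather along dim 0: every rank ends up with all shards concatenated in rank order. The single-process simulation below models only that data movement with plain Python lists; `simulate_all_gather` is hypothetical and involves no process group or devices. The shard values are the rows printed in the CPU rowwise output of the issue.

```python
from typing import List

Row = List[float]
LocalShard = List[Row]  # one rank's rows


def simulate_all_gather(shards_by_rank: List[LocalShard]) -> List[LocalShard]:
    """Hypothetical model of an all-gather along dim 0.

    shards_by_rank[r] is rank r's input buffer. Each rank gets its own,
    separate output buffer holding all shards concatenated in rank order.
    """
    gathered = [row for shard in shards_by_rank for row in shard]
    # Fresh copies per rank: the output buffers share no storage with the inputs.
    return [[list(row) for row in gathered] for _ in shards_by_rank]


# Each rank's 1x4 Shard(0) piece, copied from the CPU rowwise output.
shards = [
    [[0.3556, -0.4043, -0.2076, -1.3428]],  # rank 0
    [[0.2640, -0.4017, 1.0207, 1.5920]],    # rank 1
    [[0.7826, -1.3320, -1.1844, -1.2736]],  # rank 2
    [[0.1196, -1.8618, -0.7965, -0.1710]],  # rank 3
]
replicas = simulate_all_gather(shards)
print(replicas[1])  # rank 1's replica: same rows, in the same order, as the GPU replica output
```

The rank-ordered result matches both `gpu_replica_tensor_dist` and `gpu_replica_tensor_redist` in the output; the open question is only which memory the output buffers live in.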