pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

A large number of Tensors (>8000) in the graph will trigger an SPMD sharding error #7161

Closed mars1248 closed 4 months ago

mars1248 commented 5 months ago

šŸ› Bug

To Reproduce

Steps to reproduce the behavior: test.sh

rm -rf ./hlo_logs
#export XLA_FLAGS="--xla_gpu_enable_async_all_gather=true \

#export XLA_FLAGS="--xla_dump_to=./hlo_logs"
export XLA_FLAGS="--xla_dump_to=./hlo_logs \
    --xla_gpu_enable_analytical_latency_estimator=true \
    --xla_cpu_enable_fast_math=false \
    --xla_gpu_simplify_all_fp_conversions=false \
    --xla_gpu_force_compilation_parallelism=64  \
    --xla_gpu_enable_pipelined_collectives=true \
    --xla_gpu_enable_pipelined_all_reduce=true \
    --xla_gpu_enable_async_collectives=true \
    --xla_disable_hlo_passes=post-scheduling-passes,gpu-schedule-postprocessing \
    --xla_gpu_enable_triton_gemm=false \
"

export PJRT_ALLOCATOR_PREALLOCATE=false
export PJRT_ALLOCATOR_FRACTION=0.75
export PJRT_ALLOCATOR_CUDA_ASYNC=false
export PT_XLA_DEBUG=1
#export TF_CPP_MIN_LOG_LEVEL=0
#export TF_CPP_VMODULE="lazy_graph_executor=4,xla_graph_executor=5,nccl_collective_thunk=5,gpu_executable=5,gpu_compiler=5,service=5,collectives=5,xla_graph_executor=5,pjrt_computation_client=5,pjrt_stream_executor_client=5"
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export XLA_SAVE_TENSORS_FILE=debug.txt
#XLA_USE_SPMD=1 \
export GPU_NUM_DEVICES=8
export PJRT_DEVICE=CUDA
#CUDA_VISIBLE_DEVICES=1,2,3,4
#python test_activation_local.py
python test_multi_param_layer.py

test_multi_param_layer.py

from typing import Dict, List, Optional, Tuple, Union
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
import numpy as np
import torch.optim as optim
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import transformers
import torch_xla.debug.profiler as xp
from torch_xla.amp import syncfree
import time
import os
import sys
import math
from torch import nn
from torch.nn import Linear
from torch.autograd import Function
import torch.nn.functional as F
from torch.optim.adamw import AdamW

#device = "cuda"
#device = xm.xla_device()
xr.use_spmd()
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# mesh shape will be (num_devices, 1), e.g. (8, 1) with 8 GPUs
mesh_shape = (num_devices // 1, 1)

#mesh_shape = (2, None)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = torch.nn.ModuleList([nn.Linear(8, 8) for _ in range(400)])

    def forward(self, xt1):
        for layer in self.layers:
            xt1 = layer(xt1)
        return xt1

my_model = MyModel().to(device)

my_model = FSDPv2(my_model, mesh)
optimizer = syncfree.AdamW(my_model.parameters(), lr=0.01)
#optimizer = AdamW(my_model.parameters(), lr=0.01)
# loss = my_model(hidden_states.to(device)).sum()
# loss.backward()
# optimizer.step()
# print(loss)
t1 = torch.randn(8, 8)
for i in range(2):
    optimizer.zero_grad()
    ans = []
    partition_spec = [None] * len(t1.shape)
    partition_spec[0] = "fsdp"
    spec = xs.ShardingSpec(mesh, partition_spec)
    xt1 = xm.send_cpu_data_to_device(t1, xm.xla_device(), input_sharding=spec)[0]
    # ans.append(xt1)
    #xt1 = t1.to(device)
    #loss = my_model(xt1).sum()
    #print(loss)
    with autocast(xm.xla_device(), dtype=torch.bfloat16):
        loss = my_model(xt1).sum()
    # print(torch_xla._XLAC._get_xla_tensors_text([loss]))
    loss.backward()
    #optimizer.step()
    found_inf = torch.isnan(loss).to(torch.float32)
    optimizer.step(found_inf=found_inf)
    # xm.optimizer_step(optimizer)
    xm.mark_step()

1. sh test.sh

Expected behavior

Environment

In the preceding example, if you change the number of linear layers to 10 it works, but if you change it to 400 you get an error. I observed that on the second compilation all of the input tensors were wrapped into a single tuple, and the sharding information was lost after the post-compile optimization.
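
A rough count (my own estimate, assuming each nn.Linear contributes a weight and a bias, and that the gradients plus the two AdamW moment buffers each add another copy per parameter) suggests why 400 layers crosses a limit that 10 layers does not:

# Hedged sketch: rough estimate of how many tensors end up as graph inputs.
# The multipliers are assumptions for illustration; the count the compiler
# actually sees (data batch, step counters, found_inf, etc.) will be higher.
def estimate_graph_inputs(num_layers, tensors_per_layer=2, extra_copies_per_param=3):
    params = num_layers * tensors_per_layer        # weight + bias per nn.Linear
    return params * (1 + extra_copies_per_param)   # params + grads + exp_avg + exp_avg_sq

print(estimate_graph_inputs(10))   # ~80, far below the 3200 default mentioned below
print(estimate_graph_inputs(400))  # ~3200, right around the default threshold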

Additional context

JackCaoG commented 5 months ago

I think the wrapping into a tuple is due to https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1290-L1291. On TPU we can't have more than 3200 HLO input parameters, so we wrap them into a tuple. I would imagine SPMD still works, though. Do you have the HLO for the wrapped case?
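
One way to capture that HLO, as a rough sketch on top of the repro script above (reusing the same debug binding that is already commented out at the _get_xla_tensors_text call; the tuple-wrapped form may only show up in the post-optimization dumps under ./hlo_logs, since the wrapping happens during lowering):

# Hedged sketch: dump the HLO for the live `loss` tensor from the repro script
# right before xm.mark_step(), so the parameter list can be inspected.
import torch_xla

hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([loss])  # `loss` comes from the loop above
with open("pre_mark_step_hlo.txt", "w") as f:            # file name chosen for illustration
    f.write(hlo_text)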

mars1248 commented 5 months ago

@JackCaoG Thanks for your reply. I managed to run it successfully by turning up XLA_PARAMETER_WRAPPING_THREADSHOLD. I see the bug and the TODO are documented here: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1307. What is the purpose of the tuple input?
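
For reference, the workaround is just to raise that threshold before the first compilation; a minimal sketch (the value 10000 is arbitrary, exporting the variable in test.sh before launching Python works just as well, and the spelling matches the name used in xla_graph_executor.cpp):

# Hedged workaround sketch: raise the parameter-wrapping threshold so graphs
# with many input parameters are not wrapped into a single tuple.
# Must be set before the first graph compilation.
import os
os.environ["XLA_PARAMETER_WRAPPING_THREADSHOLD"] = "10000"  # illustrative value; the default is 3200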

JackCaoG commented 5 months ago

That TODO is actually already fixed; SPMD graph inputs and outputs can be aliased correctly. The wrapping happens at https://github.com/pytorch/xla/blob/7938bb5da6c993609aff614ccfa5b722a339d158/torch_xla/csrc/helpers.cpp#L974-L1005

JackCaoG commented 4 months ago

https://github.com/pytorch/xla/pull/7604 should fix this issue