microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/

[BUG] Zero3 for torch.compile with compiled_autograd when running LayerNorm #6719

Open yitingw1 opened 1 week ago

yitingw1 commented 1 week ago

Describe the bug

When running a simple model that includes torch.nn.LayerNorm with DeepSpeed ZeRO-3, torch.compile, and compiled_autograd, an error occurs:

site-packages/torch/_subclasses/fake_tensor.py:2017] RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]

We first found this error in a BERT model when running DeepSpeed ZeRO-3 with torch.compile and compiled_autograd.

Expected behavior

Running the model with DeepSpeed ZeRO-3 completes without error.

Investigation

The error "RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [128, 128, 1600]" occurs when compiled autograd traces the backward graph. It shows up in the LayerNorm backward decomposition, which tries to broadcast weight_cast (torch.Size([0])) to the shape of grad_out_cast ([128, 128, 1600]) and fails:

```python
if weight_cast is not None:
    grad_x_hat = grad_out_cast * weight_cast
```
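For intuition, the same multiply also fails in plain eager PyTorch once the weight has the empty placeholder shape that ZeRO-3 leaves behind for a released parameter. This is a minimal sketch (not taken from the traced graph); the [100, 120] shape corresponds to the reproducer below, a batch of 100 flattened CIFAR-10 samples with 120 LayerNorm features:

```python
import torch

grad_out_cast = torch.randn(100, 120)  # gradient w.r.t. the LayerNorm output
weight_cast = torch.empty(0)           # empty placeholder standing in for the partitioned weight

# Same multiply as in the decomposition above: eager PyTorch raises a shape-mismatch
# RuntimeError here, and fake-tensor tracing reports the broadcast error quoted above.
grad_x_hat = grad_out_cast * weight_cast
```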

If the LayerNorm weight is bypassed by using nn.LayerNorm(120, eps=1e-12, elementwise_affine=False) instead of elementwise_affine=True in deepspeed_reproducer_cpu.py, the script runs fine.

System info:

To Reproduce

Steps to reproduce the behavior:

  1. Set environment variable for more verbose logs: TORCH_LOGS="+dynamo,graph,graph_code,graph_breaks,recompiles,aot_graphs,aot_joint_graph,compiled_autograd_verbose"
  2. Run with deepspeed --num_nodes 1 --num_gpus 1 deepspeed_reproducer_cpu.py
  3. You can pass --num_gpus 2/4/8 to run on multiple cards
  4. Below is deepspeed_reproducer_cpu.py
    
```python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.distributed as dist
import deepspeed
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 120)
        self.fc2 = nn.Linear(120, 10)
        self.LayerNorm1 = nn.LayerNorm(120, eps=1e-12, elementwise_affine=True)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.LayerNorm1(x)
        x = self.fc2(x)
        return x


compile_kwargs = {"dynamic": False}
device = torch.device('cpu')

model = Net()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model_engine, optimizer, *_ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    optimizer=optimizer,
    config="./deepspeed_config.json",
)

# torch_compile
model_engine.compile(
    compile_kwargs=compile_kwargs,
)

# dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 100
trainset = torchvision.datasets.CIFAR10(
    root="./DATA/CIFAR10", train=True, download=True, transform=transform
)

# process dataset
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    sampler=DistributedSampler(trainset, shuffle=True),
    num_workers=16,
    pin_memory=True,
)
progress_bar = tqdm(
    total=len(trainloader),
    desc=f"Training 1/1 epoch",
    position=0,
    leave=True,
    disable=dist.is_initialized() and dist.get_rank() != 0,
)
for epoch in range(100):
    with torch._dynamo.compiled_autograd.enable(
        torch.compile(backend=get_accelerator().get_compile_backend(), **compile_kwargs)
    ):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)
            model_engine.backward(loss)
            model_engine.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
            progress_bar.update(1)

print("Finished Training")
```

5. Below is deepspeed_config.json
```json
{
    "train_batch_size": 32, 
    "optimizer": {
        "type": "SGD",
        "params": {
            "lr": 0.001,
            "momentum": 0.9
        }
    },
    "zero_allow_untested_optimizer": true,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": false,
        "reduce_scatter" : false,
        "contiguous_gradients" : false
    }
}
```

tohtana commented 5 days ago

Hi @yitingw1, I wonder if persistent parameters might not work well with the compiler. Can you try setting stage3_param_persistence_threshold to zero?
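For reference, a minimal sketch of where that knob goes, reusing the reproducer's config values; passing the config as a dict to deepspeed.initialize is purely for illustration (the same key can equally be added under "zero_optimization" in deepspeed_config.json):

```python
import deepspeed
import torch.optim as optim

# Sketch only: the reproducer's config with the suggested threshold added.
ds_config = {
    "train_batch_size": 32,
    "optimizer": {"type": "SGD", "params": {"lr": 0.001, "momentum": 0.9}},
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": False,
        "reduce_scatter": False,
        "contiguous_gradients": False,
        "stage3_param_persistence_threshold": 0,  # keep no parameters persistently gathered
    },
}

model = Net()  # Net as defined in deepspeed_reproducer_cpu.py above
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model_engine, optimizer, *_ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    optimizer=optimizer,
    config=ds_config,
)
```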

yitingw1 commented 3 days ago

Hi @tohtana, I have tried setting stage3_param_persistence_threshold to zero, but it doesn't seem to help; the error still occurs. I also opened an issue in PyTorch.