pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org
Other
83.28k stars 22.46k forks source link

Strange recompilations on torch 2.5 + FSDP + UNet #138813

Open GLivshits opened 3 hours ago

GLivshits commented 3 hours ago

🐛 Describe the bug

Compiling the plain UNet works fine, but the FSDP-wrapped UNet gets recompiled for every block. In a real setup, the Dynamo cache-size limit is reached quickly.

Code:

import argparse
import os
from contextlib import nullcontext
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import RMSNorm
from torch.nn.parallel import DistributedDataParallel
from tqdm.auto import tqdm

# Global Dynamo settings for this reproduction. They run at import time, so every
# rank started by torch.multiprocessing.spawn (spawn start method re-imports this
# module) applies them too.
# NOTE(review): disables Dynamo's inlining of built-in torch.nn modules; presumably
# this differs from the torch 2.5 default and affects how module instances are
# guarded — confirm.
torch._dynamo.config.inline_inbuilt_nn_modules = False
# Turn off DDPOptimizer (graph splitting at DDP bucket boundaries).
torch._dynamo.config.optimize_ddp = False

def setup(rank, world_size):
    """Bind this process to its GPU and join the NCCL process group.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Single-node rendezvous on a fixed local port.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12355")
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group created by ``setup``."""
    return dist.destroy_process_group()

def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``.

    Used to zero-initialise the last conv of each residual branch and the
    network's output conv, so those layers initially output zeros.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class SpatialToSeq(nn.Module):
    """Flatten a channels-first feature map into a token sequence.

    ``(b, c, h, w) -> (b, h * w, c)``; token ``i`` is pixel ``(i // w, i % w)``
    (row-major). Extra positional/keyword arguments are accepted and ignored so
    the layer can be chained with layers that consume ``cond``/``cond_mask``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        b, c, h, w = x.shape
        # reshape instead of view: after the permute, an arbitrarily strided
        # input (e.g. a width-sliced map) may not allow merging h and w without
        # a copy, and view() would raise. reshape returns the identical zero-copy
        # view whenever view() would have succeeded.
        return x.permute(0, 2, 3, 1).reshape(b, h * w, c)

class SeqToSpatial(nn.Module):
    """Fold a token sequence back into a square channels-first feature map.

    ``(b, n, c) -> (b, c, s, s)`` with ``s = int(n ** 0.5)``. This is the inverse
    of a row-major spatial flatten for square maps. Extra arguments are accepted
    and ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        batch, tokens, dim = x.shape
        side = int(tokens**0.5)
        channels_first = x.permute(0, 2, 1)
        return channels_first.view(batch, dim, side, side)

class SelfAttention(nn.Module):
    def __init__(self, input_dim: int, out_dim: int, d_head: int):
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.d_head = d_head
        self.n_heads = self.out_dim // self.d_head
        self.d_attn = self.out_dim

        self.pre_norm = nn.LayerNorm(input_dim)
        self.qkv_proj = nn.Linear(input_dim, 3 * self.d_attn, bias=False)
        self.q_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.k_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.to_out = nn.Linear(self.d_attn, self.out_dim)

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None, cond_mask: Optional[torch.Tensor] = None):
        x = self.pre_norm(x)
        q, k, v = self.qkv_proj(x).chunk(dim=-1, chunks=3)
        q = self.q_norm(q)
        k = self.k_norm(k)

        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_heads)

        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=self.n_heads)
        out = self.to_out(out)
        return out

class CrossAttention(nn.Module):
    def __init__(self, input_dim: int, cond_dim: int, out_dim: int, d_head: int):
        super().__init__()
        self.input_dim = input_dim
        self.cond_dim = cond_dim
        self.out_dim = out_dim
        self.d_head = d_head
        self.n_heads = self.out_dim // self.d_head
        self.d_attn = self.out_dim

        self.pre_norm = nn.LayerNorm(input_dim)
        self.cond_pre_norm = nn.LayerNorm(cond_dim)
        self.q_proj = nn.Linear(input_dim, self.d_attn, bias=False)
        self.kv_proj = nn.Linear(cond_dim, 2 * self.d_attn, bias=False)
        self.q_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.k_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.to_out = nn.Linear(self.d_attn, self.out_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, cond_mask: Optional[torch.Tensor] = None):
        x = self.pre_norm(x)
        cond = self.cond_pre_norm(cond)
        q = self.q_proj(x)
        k, v = self.kv_proj(cond).chunk(dim=-1, chunks=2)
        q = self.q_norm(q)
        k = self.k_norm(k)

        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_heads)
        if cond_mask is not None:
            cond_mask = cond_mask.unsqueeze(1).unsqueeze(1)

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=cond_mask)
        out = rearrange(out, "b h n d -> b n (h d)", h=self.n_heads)
        out = self.to_out(out)
        return out

class Upsample(nn.Module):
    """Double spatial resolution with nearest-neighbour interpolation.

    Extra positional/keyword arguments are accepted and ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return F.interpolate(x, scale_factor=2, mode="nearest")

class Downsample(nn.Module):
    """Halve spatial resolution with 2x2 average pooling (stride 2).

    Extra positional/keyword arguments are accepted and ignored.
    """

    def __init__(self):
        super().__init__()
        self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        pooled = self.op(x)
        return pooled

class Sequential(nn.Sequential):
    """``nn.Sequential`` that forwards extra arguments to every child.

    The running tensor is threaded through the children in registration order.
    Any additional positional/keyword arguments (in UNet.forward: ``cond`` and
    ``cond_mask``) are passed unchanged to each child.
    """

    def forward(self, x, *args, **kwargs):
        hidden = x
        for child in self:
            hidden = child(hidden, *args, **kwargs)
        return hidden

class ResBlock(nn.Module):
    """Residual conv block with optional up/down resampling.

    Main path: GroupNorm -> SiLU -> [resample] -> 3x3 conv, then
    GroupNorm -> SiLU -> Dropout -> zero-initialised 3x3 conv.
    Skip path: [resample] -> identity, or a 1x1 conv when widths differ or
    ``use_conv`` is set. The two paths are summed.

    Args:
        channels: Input channel count.
        dropout: Dropout probability before the last conv.
        out_channels: Output channels; defaults to ``channels``.
        mid_channels: Width between the two convs; defaults to ``out_channels``.
        use_conv: Force a 1x1 conv on the skip path even when widths match.
        up: Apply Upsample to both paths.
        down: Apply Downsample to both paths (ignored when ``up`` is set).
        norm_groups: Groups for the first GroupNorm; the second GroupNorm uses
            this scaled by ``mid_channels / out_channels`` (at least 1).
            nn.GroupNorm requires the channel count to be divisible by it.
    """

    def __init__(
        self,
        channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        mid_channels: Optional[int] = None,
        use_conv: bool = False,
        up: bool = False,
        down: bool = False,
        norm_groups: int = 32,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.mid_channels = mid_channels or self.out_channels
        self.use_conv = use_conv

        act_and_conv = [nn.SiLU(), nn.Conv2d(channels, self.mid_channels, 3, padding=1)]
        self.in_layers = nn.ModuleList([nn.GroupNorm(num_channels=channels, num_groups=norm_groups)] + act_and_conv)
        self.in_layers_len = len(self.in_layers)
        self.updown = up or down

        # Both paths get their own resampler instance (none hold parameters).
        resampler_cls = Upsample if up else (Downsample if down else None)
        if resampler_cls is None:
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler_cls()
            self.x_upd = resampler_cls()

        # Scale the group count for a narrower/wider middle width.
        mid_groups = max(norm_groups * self.mid_channels // self.out_channels, 1)
        self.out_layers = nn.ModuleList(
            [
                nn.GroupNorm(num_channels=self.mid_channels, num_groups=mid_groups),
                nn.SiLU(),
                nn.Dropout(p=dropout),
                zero_module(nn.Conv2d(self.mid_channels, self.out_channels, 3, padding=1)),
            ]
        )
        self.out_layers_len = len(self.out_layers)

        needs_projection = use_conv or self.out_channels != channels
        self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) if needs_projection else nn.Identity()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        last_in = self.in_layers_len - 1
        h = x
        # Norm + activation run before resampling; the first conv runs after it.
        for idx in range(last_in):
            h = self.in_layers[idx](h)
        if self.updown:
            h, x = self.h_upd(h), self.x_upd(x)
        h = self.in_layers[last_in](h)

        for idx in range(self.out_layers_len):
            h = self.out_layers[idx](h)
        return self.skip_connection(x) + h

class UNet(nn.Module):
    """Conditional UNet over 2D feature maps with attention stacks.

    Layout:
      1. A 3x3 stem conv.
      2. One down block per consecutive pair in ``channels``: ResBlock, optional
         attention stack, then a downsampling ResBlock.
      3. A middle block.
      4. Mirrored up blocks. Each concatenates the matching down-path activation
         along the channel axis, then applies ResBlock, an optional attention
         stack and an upsampling ResBlock.
      5. A zero-initialised 3x3 conv back to ``in_dim`` channels.

    Args:
        in_dim: Channels of the input (and of the output).
        cond_dim: Feature size of the ``cond`` tokens consumed by CrossAttention.
        channels: Width of each resolution level; ``channels[0]`` is the stem width.
        attns: Number of (self-attention, cross-attention) pairs per level; must
            have ``len(channels) - 1`` entries.
        middle_attns: Number of attention pairs in the middle block.
    """

    def __init__(self, in_dim: int, cond_dim: int, channels: List[int], attns: List[int], middle_attns: int = 0):
        super().__init__()
        assert len(attns) == len(channels) - 1, "attns needs one entry per down/up level (len(channels) - 1)"

        self.in_dim = in_dim
        # Registered before the stem so parameter order (down, up, stem, middle,
        # head) is unchanged for optimizers and state dicts.
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Channel count of each activation the down path pushes onto the skip
        # stack; the up path pops from the end. channels[0] is never popped.
        in_chs = [channels[0]]
        # Width of the deepest level. Bound here so a single-level UNet (no
        # down/up blocks) still builds; otherwise out_ch was only set by the
        # down loop and the middle block raised NameError.
        out_ch = channels[0]

        self.in_block = nn.Conv2d(in_dim, channels[0], kernel_size=3, padding=1)

        for level, (ch, out_ch) in enumerate(zip(channels[:-1], channels[1:])):
            layer = [ResBlock(ch, 0.0, out_ch, out_ch)]
            layer += self._attention_stack(out_ch, cond_dim, attns[level])
            layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch, down=True))
            self.down_blocks.append(Sequential(*layer))
            in_chs.append(out_ch)

        layer = [ResBlock(out_ch, 0.0, out_ch, out_ch)]
        layer += self._attention_stack(out_ch, cond_dim, middle_attns)
        layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch))
        self.middle_block = Sequential(*layer)

        reversed_channels = channels[::-1]
        for step, (ch1, ch2) in enumerate(zip(reversed_channels[:-1], reversed_channels[1:])):
            # Up step ``step`` mirrors down level ``len(attns) - 1 - step``.
            level = len(attns) - 1 - step
            # Input = this level's activation concatenated with the matching skip.
            ch = ch1 + in_chs.pop()
            out_ch = ch2
            layer = [ResBlock(ch, 0.0, out_ch, out_ch)]
            layer += self._attention_stack(out_ch, cond_dim, attns[level])
            layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch, up=True))
            self.up_blocks.append(Sequential(*layer))

        self.out_block = zero_module(nn.Conv2d(out_ch, in_dim, kernel_size=3, padding=1))

    @staticmethod
    def _attention_stack(dim: int, cond_dim: int, depth: int) -> List[nn.Module]:
        """Return ``depth`` (SelfAttention, CrossAttention) pairs wrapped in
        spatial->sequence->spatial reshapes, or ``[]`` when ``depth <= 0``.

        SeqToSpatial folds tokens back into a square map, so feature maps at
        levels with attention are presumably expected to be square.
        """
        if depth <= 0:
            return []
        stack: List[nn.Module] = [SpatialToSeq()]
        for _ in range(depth):
            stack.append(SelfAttention(dim, dim, 64))
            stack.append(CrossAttention(dim, cond_dim, dim, 64))
        stack.append(SeqToSpatial())
        return stack

    def forward(self, x: torch.Tensor, cond: torch.Tensor, cond_mask: Optional[torch.Tensor] = None):
        """Run the UNet.

        Args:
            x: Input feature map with ``in_dim`` channels; presumably
                ``(batch, in_dim, H, W)`` with H, W divisible by
                ``2 ** (len(channels) - 1)`` — TODO confirm.
            cond: Conditioning tokens, passed to every block (used by CrossAttention).
            cond_mask: Optional mask forwarded to CrossAttention as the SDPA mask.

        Returns:
            Tensor with ``in_dim`` channels produced by the zero-initialised head.
        """
        skips = []
        x = self.in_block(x)

        for block in self.down_blocks:
            x = block(x, cond, cond_mask)
            skips.append(x)

        x = self.middle_block(x, cond, cond_mask)

        for block in self.up_blocks:
            # Pair each up block with the down activation of the same resolution.
            x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, cond, cond_mask)

        return self.out_block(x)

def parse_args(argv: Optional[List[str]] = None):
    """Parse the reproduction script's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``batch_size``, ``num_iterations`` and the
        boolean feature flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32, help="Samples generated per rank per step.")
    parser.add_argument("--num_iterations", type=int, default=200, help="Number of training steps.")
    parser.add_argument("--use_ddp", action="store_true", help="Wrap the model in DistributedDataParallel.")
    parser.add_argument("--use_fsdp", action="store_true", help="Wrap the model in FullyShardedDataParallel.")
    parser.add_argument("--use_compile", action="store_true", help="Call model.compile() before wrapping.")
    parser.add_argument("--use_controlnet", action="store_true", help="Force FSDP use_orig_params=True.")
    parser.add_argument("--disable_fa2", action="store_true", help="Accepted but not read by main().")
    return parser.parse_args(argv)

def main(rank, world_size, args):
    """Train the toy UNet on random data on one rank of a ``world_size`` job.

    Joins the NCCL process group and builds the UNet, optionally compiling it.
    It then wraps the model in FSDP (HYBRID_SHARD, one unit per ``nn.Sequential``
    block) or DDP, and runs an fp16 mixed-precision loop with gradient scaling.
    The process group is always torn down, even if training raises.

    Args:
        rank: Index of this process; also the CUDA device index.
        world_size: Total number of processes.
        args: Options from ``parse_args``.

    Raises:
        AssertionError: If both ``--use_ddp`` and ``--use_fsdp`` are given.
        ValueError: If the loss summed across ranks is NaN on any step.
    """
    # Validate before joining the process group so nothing needs tearing down.
    assert not (args.use_ddp and args.use_fsdp)

    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16
        cond_dim = 1024
        cond_len = 128

        model = UNet(4, cond_dim, [128, 256, 512, 512], [2, 2, 2], 2).to(device)

        if args.use_compile:
            print("Trying compile.")
            # NOTE(review): with dynamic=False every distinct input shape seen by a
            # compiled frame gets its own compilation. If FSDP makes each wrapped
            # block a separate frame sharing one ``forward`` code object, a deep
            # model can exhaust torch._dynamo.config.cache_size_limit — TODO confirm.
            model.compile(mode="default", dynamic=False)

        if args.use_fsdp:
            model = FSDP(
                module=model,
                device_id=rank,
                use_orig_params=args.use_compile or args.use_controlnet,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
                forward_prefetch=True,
                limit_all_gathers=True,
                auto_wrap_policy=ModuleWrapPolicy({nn.Sequential}),
                mixed_precision=MixedPrecision(
                    param_dtype=dtype,
                    buffer_dtype=dtype,
                    reduce_dtype=dtype,
                ),
            )
            # FSDP's MixedPrecision already casts params/buffers to ``dtype``, so
            # autocast wraps only the loss computation.
            loss_amp_context = torch.amp.autocast("cuda", dtype=dtype, enabled=True)
            model_amp_context = nullcontext()
            scaler = ShardedGradScaler(enabled=dtype == torch.float16)
        else:
            if args.use_ddp:
                model = DistributedDataParallel(
                    model, broadcast_buffers=False, gradient_as_bucket_view=True, find_unused_parameters=False
                )
            loss_amp_context = torch.amp.autocast("cuda", dtype=dtype, enabled=True)
            model_amp_context = loss_amp_context
            scaler = torch.amp.GradScaler("cuda", enabled=dtype == torch.float16)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.98))

        iterator = range(args.num_iterations)
        if rank == 0:
            iterator = tqdm(iterator, total=args.num_iterations)

        for _ in iterator:
            x = torch.randn(args.batch_size, 4, 64, 64, device=device)
            cond = torch.randn(args.batch_size, cond_len, cond_dim, device=device)
            cond_mask = torch.randn(args.batch_size, cond_len, device=device) > 0
            with model_amp_context:
                out = model(x, cond, cond_mask)
            with loss_amp_context:
                loss = F.mse_loss(out, x)  # (input, target); MSE is symmetric either way

            # Detached copy: the in-place all_reduce must not touch the local loss
            # or its autograd graph. The default op is SUM, so a NaN on any rank
            # propagates to every rank and all of them stop together.
            loss_sum = loss.detach().clone()
            dist.all_reduce(loss_sum)
            if rank == 0:
                # Report the mean over ranks, not the raw world_size-scaled sum.
                iterator.set_description(f"Loss: {loss_sum.item() / world_size}")
            if torch.isnan(loss_sum):
                raise ValueError("NaN loss.")

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    finally:
        cleanup()

if __name__ == "__main__":
    # multiprocessing documents freeze_support() as the first call in the main
    # block (it only matters for frozen Windows executables).
    torch.multiprocessing.freeze_support()
    args = parse_args()
    world_size = torch.cuda.device_count()
    if world_size == 0:
        # spawn(nprocs=0) would start nothing and return, silently doing no work.
        raise SystemExit("No CUDA devices are visible; this script needs at least one GPU.")
    if world_size == 1:
        # Run in-process; no need to spawn workers for a single rank.
        main(0, world_size, args)
    else:
        torch.multiprocessing.spawn(fn=main, args=(world_size, args), nprocs=world_size, join=True)

Command:

TORCH_LOGS=recompiles CUDA_VISIBLE_DEVICES=4,6 python compile_debug.py --use_fsdp --use_compile

Output:

Trying compile. Trying compile. 0%| | 0/200 [00:00<?, ?it/s][rank0]:W1024 16:27:32.344000 1770485 site-packages/torch/_logging/_internal.py:1081] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored [rank1]:W1024 16:27:32.344000 1770486 site-packages/torch/_logging/_internal.py:1081] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 256 [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 256 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. 
expected 256, actual 512 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [__recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 8 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 512 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 8 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [__recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/4: tensor 'L['x']' size mismatch at index 2. expected 8, actual 16 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. 
expected 512, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 2. expected 8, actual 16 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 1024 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/5: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 2. expected 8, actual 32 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 32 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/5: tensor 'L['x']' size mismatch at index 1. 
expected 1024, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 2. expected 8, actual 32 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 32 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 Loss: 1.9992671012878418: 0%| | 0/200 [00:38<?, ?it/s]/home/user/anaconda3/envs/torch25/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.) return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass /home/user/anaconda3/envs/torch25/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.) 
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass Loss: 2.003182888031006: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:25<00:00, 2.34it/s]

Versions

PyTorch version: 2.5.0+cu124 Is debug build: False CUDA used to build PyTorch: 12.4 ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64) GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 Clang version: 14.0.0-1ubuntu1.1 CMake version: version 3.30.0 Libc version: glibc-2.35

Python version: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] (64-bit runtime) Python platform: Linux-5.4.210-39.1-x86_64-with-glibc2.35 Is CUDA available: True CUDA runtime version: 12.4.99 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA A800-SXM4-80GB GPU 1: NVIDIA A800-SXM4-80GB GPU 2: NVIDIA A800-SXM4-80GB GPU 3: NVIDIA A800-SXM4-80GB GPU 4: NVIDIA A800-SXM4-80GB GPU 5: NVIDIA A800-SXM4-80GB GPU 6: NVIDIA A800-SXM4-80GB GPU 7: NVIDIA A800-SXM4-80GB

Nvidia driver version: 550.54.14 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True

CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 48 bits physical, 48 bits virtual Byte Order: Little Endian CPU(s): 120 On-line CPU(s) list: 0-119 Vendor ID: AuthenticAMD Model name: AMD EPYC 7662 64-Core Processor CPU family: 23 Model: 49 Thread(s) per core: 1 Core(s) per socket: 1 Socket(s): 120 Stepping: 0 BogoMIPS: 3992.45 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd arat npt nrip_save umip rdpid arch_capabilities Virtualization: AMD-V L1d cache: 7.5 MiB (120 instances) L1i cache: 7.5 MiB (120 instances) L2 cache: 60 MiB (120 instances) L3 cache: 1.9 GiB (120 instances) NUMA node(s): 4 NUMA node0 CPU(s): 0-29 NUMA node1 CPU(s): 30-59 NUMA node2 CPU(s): 60-89 NUMA node3 CPU(s): 90-119 Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected

Versions of relevant libraries: [pip3] flake8==5.0.4 [pip3] mypy-extensions==1.0.0 [pip3] numpy==1.26.4 [pip3] open-clip-torch==2.24.0 [pip3] pytorch-warmup==0.1.1 [pip3] torch==2.5.0 [pip3] torch-fidelity==0.3.0 [pip3] torch-model-archiver==0.11.1 [pip3] torch-tb-profiler==0.4.3 [pip3] torch-workflow-archiver==0.2.14 [pip3] torchaudio==2.5.0 [pip3] torchdata==0.7.1 [pip3] torchmetrics==1.4.0.post0 [pip3] torchsde==0.2.6 [pip3] torchserve==0.11.1 [pip3] torchvision==0.20.0 [pip3] triton==3.1.0 [conda] numpy 1.26.4 pypi_0 pypi [conda] open-clip-torch 2.24.0 pypi_0 pypi [conda] torch 2.5.0 pypi_0 pypi [conda] torch-fidelity 0.3.0 pypi_0 pypi [conda] torch-model-archiver 0.11.1 pypi_0 pypi [conda] torch-tb-profiler 0.4.3 pypi_0 pypi [conda] torch-workflow-archiver 0.2.14 pypi_0 pypi [conda] torchaudio 2.5.0 pypi_0 pypi [conda] torchdata 0.7.1 pypi_0 pypi [conda] torchmetrics 1.4.0.post0 pypi_0 pypi [conda] torchsde 0.2.6 pypi_0 pypi [conda] torchserve 0.11.1 pypi_0 pypi [conda] torchvision 0.20.0 pypi_0 pypi [conda] triton 3.1.0 pypi_0 pypi

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec

GLivshits commented 3 hours ago

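It seems that, after the first FSDP-wrapped module is compiled, the compiler treats the next one as the previous one: every recompilation in the log above is triggered by an input-shape guard failure on the same `forward`.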