Open GLivshits opened 3 hours ago
Compiling the plain UNet model works fine, but the FSDP-wrapped UNet is recompiled for every block. In the real setup, the recompilation cache-size limit is reached quickly.
Code:
import argparse
import os
from contextlib import nullcontext
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import RMSNorm
from torch.nn.parallel import DistributedDataParallel
from tqdm.auto import tqdm

# Dynamo settings used by this reproducer (kept exactly as in the report).
torch._dynamo.config.inline_inbuilt_nn_modules = False
torch._dynamo.config.optimize_ddp = False


def setup(rank, world_size):
    """Initialise a single-host NCCL process group and bind this rank to its GPU."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the process group created by ``setup``."""
    dist.destroy_process_group()


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class SpatialToSeq(nn.Module):
    """Flatten a 4-D ``(b, c, h, w)`` map into a ``(b, h * w, c)`` sequence."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Extra args (cond, cond_mask) are accepted so this fits in ``Sequential``.
        b, c, h, w = x.shape
        return x.permute(0, 2, 3, 1).view(b, h * w, c)


class SeqToSpatial(nn.Module):
    """Inverse of ``SpatialToSeq``: ``(b, n, c)`` -> ``(b, c, sqrt(n), sqrt(n))``.

    Assumes the sequence came from a square feature map (``n`` a perfect square);
    otherwise the ``view`` below fails.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        b, n, c = x.shape
        spatial_dim = int(n**0.5)
        return x.permute(0, 2, 1).view(b, c, spatial_dim, spatial_dim)


class SelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-normalisation.

    ``out_dim`` must be divisible by ``d_head`` (``n_heads = out_dim // d_head``).
    NOTE(review): the output replaces ``x`` (there is no residual connection);
    confirm that is intended for this model.
    """

    def __init__(self, input_dim: int, out_dim: int, d_head: int):
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.d_head = d_head
        self.n_heads = self.out_dim // self.d_head
        self.d_attn = self.out_dim
        self.pre_norm = nn.LayerNorm(input_dim)
        self.qkv_proj = nn.Linear(input_dim, 3 * self.d_attn, bias=False)
        self.q_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.k_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.to_out = nn.Linear(self.d_attn, self.out_dim)

    def forward(
        self,
        x: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        cond_mask: Optional[torch.Tensor] = None,
    ):
        # ``cond`` / ``cond_mask`` are ignored; they exist so the block fits in ``Sequential``.
        x = self.pre_norm(x)
        q, k, v = self.qkv_proj(x).chunk(dim=-1, chunks=3)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_heads)
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=self.n_heads)
        out = self.to_out(out)
        return out


class CrossAttention(nn.Module):
    """Multi-head cross-attention from ``x`` (queries) to ``cond`` (keys/values).

    NOTE(review): like ``SelfAttention``, there is no residual connection.
    """

    def __init__(self, input_dim: int, cond_dim: int, out_dim: int, d_head: int):
        super().__init__()
        self.input_dim = input_dim
        self.cond_dim = cond_dim
        self.out_dim = out_dim
        self.d_head = d_head
        self.n_heads = self.out_dim // self.d_head
        self.d_attn = self.out_dim
        self.pre_norm = nn.LayerNorm(input_dim)
        self.cond_pre_norm = nn.LayerNorm(cond_dim)
        self.q_proj = nn.Linear(input_dim, self.d_attn, bias=False)
        self.kv_proj = nn.Linear(cond_dim, 2 * self.d_attn, bias=False)
        self.q_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.k_norm = RMSNorm(self.d_attn, eps=1e-6)
        self.to_out = nn.Linear(self.d_attn, self.out_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, cond_mask: Optional[torch.Tensor] = None):
        x = self.pre_norm(x)
        cond = self.cond_pre_norm(cond)
        q = self.q_proj(x)
        k, v = self.kv_proj(cond).chunk(dim=-1, chunks=2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.n_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.n_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.n_heads)
        if cond_mask is not None:
            # (b, cond_len) -> (b, 1, 1, cond_len): broadcast over heads and queries.
            # NOTE(review): a row with no True entries may yield NaN from SDPA.
            cond_mask = cond_mask.unsqueeze(1).unsqueeze(1)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=cond_mask)
        out = rearrange(out, "b h n d -> b n (h d)", h=self.n_heads)
        out = self.to_out(out)
        return out


class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsampling."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return x


class Downsample(nn.Module):
    """2x spatial downsampling via 2x2 average pooling."""

    def __init__(self):
        super().__init__()
        self.op = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return self.op(x)


class Sequential(nn.Sequential):
    """``nn.Sequential`` that forwards extra positional/keyword args to every layer."""

    def forward(self, x, *args, **kwargs):
        for layer in self:
            x = layer(x, *args, **kwargs)
        return x


class ResBlock(nn.Module):
    """Residual conv block with optional 2x up- or down-sampling."""

    def __init__(
        self,
        channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        mid_channels: Optional[int] = None,
        use_conv: bool = False,
        up: bool = False,
        down: bool = False,
        norm_groups: int = 32,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.mid_channels = mid_channels or self.out_channels
        self.use_conv = use_conv
        conv_block = [nn.SiLU(), nn.Conv2d(channels, self.mid_channels, 3, padding=1)]
        self.in_layers = nn.ModuleList([nn.GroupNorm(num_channels=channels, num_groups=norm_groups), *conv_block])
        self.in_layers_len = len(self.in_layers)
        self.updown = up or down
        if up:
            self.h_upd = Upsample()
            self.x_upd = Upsample()
        elif down:
            self.h_upd = Downsample()
            self.x_upd = Downsample()
        else:
            self.h_upd = self.x_upd = nn.Identity()
        # Scale the group count down proportionally when mid_channels < out_channels
        # (shrunken model); never go below one group.
        norm_groups = max(norm_groups * self.mid_channels // self.out_channels, 1)
        self.out_layers = nn.ModuleList(
            [
                nn.GroupNorm(num_channels=self.mid_channels, num_groups=norm_groups),
                nn.SiLU(),
                nn.Dropout(p=dropout),
                zero_module(nn.Conv2d(self.mid_channels, self.out_channels, 3, padding=1)),
            ]
        )
        self.out_layers_len = len(self.out_layers)
        # A 1x1 projection is needed when requested or when channel counts differ.
        if use_conv or self.out_channels != channels:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        h = x
        # Norm + activation first; resampling happens before the final conv.
        for i in range(self.in_layers_len - 1):
            h = self.in_layers[i](h)
        if self.updown:
            h = self.h_upd(h)
            x = self.x_upd(x)
        h = self.in_layers[self.in_layers_len - 1](h)
        for i in range(self.out_layers_len):
            h = self.out_layers[i](h)
        out = self.skip_connection(x) + h
        return out


class UNet(nn.Module):
    """UNet with optional self/cross-attention stages conditioned on ``cond``.

    ``attns[i]`` is the number of (SelfAttention, CrossAttention) pairs in the i-th
    down stage and its mirrored up stage; ``middle_attns`` the same for the middle.
    """

    def __init__(self, in_dim: int, cond_dim: int, channels: List[int], attns: List[int], middle_attns: int = 0):
        super().__init__()
        assert len(attns) == len(channels) - 1
        self.in_dim = in_dim
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        ch = channels[0]
        in_chs = [ch]
        self.in_block = nn.Conv2d(in_dim, channels[0], kernel_size=3, padding=1)
        for i, (ch, out_ch) in enumerate(zip(channels[:-1], channels[1:])):
            layer = [ResBlock(ch, 0.0, out_ch, out_ch)]
            if attns[i] > 0:
                layer.append(SpatialToSeq())
                for _ in range(attns[i]):
                    layer.append(SelfAttention(out_ch, out_ch, 64))
                    layer.append(CrossAttention(out_ch, cond_dim, out_ch, 64))
                layer.append(SeqToSpatial())
            layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch, down=True))
            self.down_blocks.append(Sequential(*layer))
            in_chs.append(out_ch)
        layer = [ResBlock(out_ch, 0.0, out_ch, out_ch)]
        if middle_attns > 0:
            layer.append(SpatialToSeq())
            for _ in range(middle_attns):
                layer.append(SelfAttention(out_ch, out_ch, 64))
                layer.append(CrossAttention(out_ch, cond_dim, out_ch, 64))
            layer.append(SeqToSpatial())
        layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch))
        self.middle_block = Sequential(*layer)
        for i, (ch1, ch2) in enumerate(zip(channels[::-1][:-1], channels[::-1][1:])):
            i = len(attns) - 1 - i
            # Input is the previous stage concatenated with the matching skip tensor.
            ch = ch1 + in_chs.pop()
            out_ch = ch2
            layer = [ResBlock(ch, 0.0, out_ch, out_ch)]
            if attns[i] > 0:
                layer.append(SpatialToSeq())
                for _ in range(attns[i]):
                    layer.append(SelfAttention(out_ch, out_ch, 64))
                    layer.append(CrossAttention(out_ch, cond_dim, out_ch, 64))
                layer.append(SeqToSpatial())
            layer.append(ResBlock(out_ch, 0.0, out_ch, out_ch, up=True))
            self.up_blocks.append(Sequential(*layer))
        self.out_block = zero_module(nn.Conv2d(out_ch, in_dim, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor, cond: torch.Tensor, cond_mask: Optional[torch.Tensor] = None):
        res = []
        x = self.in_block(x)
        for layer in self.down_blocks:
            x = layer(x, cond, cond_mask)
            res.append(x)
        x = self.middle_block(x, cond, cond_mask)
        for layer in self.up_blocks:
            x = torch.cat([x, res.pop()], dim=1)
            x = layer(x, cond, cond_mask)
        x = self.out_block(x)
        return x


def parse_args():
    """Parse the command-line options of the reproducer."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_iterations", type=int, default=200)
    parser.add_argument("--use_ddp", action="store_true")
    parser.add_argument("--use_fsdp", action="store_true")
    parser.add_argument("--use_compile", action="store_true")
    parser.add_argument("--use_controlnet", action="store_true")
    # NOTE(review): parsed but not used anywhere in this script.
    parser.add_argument("--disable_fa2", action="store_true")
    parser.add_argument(
        "--cache_size_limit",
        type=int,
        default=None,
        help="Override torch._dynamo.config.cache_size_limit (per-code-object recompile budget).",
    )
    args = parser.parse_args()
    return args


def _train(rank, world_size, args):
    """Build, wrap/compile and train the UNet on random data for one rank."""
    device = torch.device(f"cuda:{rank}")
    dtype = torch.float16
    cond_dim = 1024
    cond_len = 128

    cache_size_limit = getattr(args, "cache_size_limit", None)
    if cache_size_limit is not None:
        # With dynamic=False, every distinct input shape that reaches the same
        # compiled forward() needs its own cache entry. The recompile log shows
        # one entry per UNet block shape, so deeper models exceed the default
        # budget. Raising the limit lets every block keep its compiled graph.
        torch._dynamo.config.cache_size_limit = cache_size_limit

    model = UNet(4, cond_dim, [128, 256, 512, 512], [2, 2, 2], 2).to(device)
    if args.use_compile:
        print("Trying compile.")
        model.compile(mode="default", dynamic=False)

    if args.use_fsdp:
        model = FSDP(
            module=model,
            device_id=rank,
            use_orig_params=args.use_compile or args.use_controlnet,
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            forward_prefetch=True,
            limit_all_gathers=True,
            # Every (subclass of) nn.Sequential, i.e. each UNet block, becomes its own FSDP unit.
            auto_wrap_policy=ModuleWrapPolicy({nn.Sequential}),
            mixed_precision=MixedPrecision(
                param_dtype=dtype,
                buffer_dtype=dtype,
                reduce_dtype=dtype,
            ),
        )
        loss_amp_context = torch.amp.autocast("cuda", dtype=dtype, enabled=True)
        model_amp_context = nullcontext()  # FSDP MixedPrecision already casts params
        scaler = ShardedGradScaler(enabled=dtype == torch.float16)
    else:
        if args.use_ddp:
            model = DistributedDataParallel(
                model, broadcast_buffers=False, gradient_as_bucket_view=True, find_unused_parameters=False
            )
        loss_amp_context = torch.amp.autocast("cuda", dtype=dtype, enabled=True)
        model_amp_context = loss_amp_context
        scaler = torch.amp.GradScaler("cuda", enabled=dtype == torch.float16)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.98))
    iterator = range(args.num_iterations)
    if rank == 0:
        iterator = tqdm(iterator, total=args.num_iterations)
    for _ in iterator:
        x = torch.randn(args.batch_size, 4, 64, 64, device=device)
        cond = torch.randn(args.batch_size, cond_len, cond_dim, device=device)
        cond_mask = torch.randn(args.batch_size, cond_len, device=device) > 0
        with model_amp_context:
            out = model(x, cond, cond_mask)
        with loss_amp_context:
            loss = F.mse_loss(x, out)
        # Detached copy: the in-place all-reduce must not touch the loss used for backward.
        loss_test = loss.detach().clone()
        dist.all_reduce(loss_test)  # default op is SUM
        loss_test /= world_size  # report the mean over ranks rather than the sum
        if rank == 0:
            iterator.set_description(f"Loss: {loss_test.item()}")
        # A NaN on any rank propagates through the reduction, so all ranks stop together.
        if torch.isnan(loss_test):
            raise ValueError("NaN loss.")
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()


def main(rank, world_size, args):
    """Per-rank entry point (also the ``torch.multiprocessing.spawn`` target).

    Raises:
        ValueError: if both DDP and FSDP are requested, or if a NaN loss is seen.
    """
    if args.use_ddp and args.use_fsdp:
        raise ValueError("--use_ddp and --use_fsdp are mutually exclusive.")
    setup(rank, world_size)
    try:
        _train(rank, world_size, args)
    finally:
        # Release the process group even when training raises (e.g. NaN loss).
        cleanup()


if __name__ == "__main__":
    args = parse_args()
    world_size = torch.cuda.device_count()
    torch.multiprocessing.freeze_support()
    if world_size == 0:
        raise RuntimeError("This script needs at least one CUDA device (NCCL backend).")
    if world_size == 1:
        main(0, world_size, args)
    else:
        torch.multiprocessing.spawn(fn=main, args=(world_size, args), nprocs=world_size, join=True)
Command:
TORCH_LOGS=recompiles CUDA_VISIBLE_DEVICES=4,6 python compile_debug.py --use_fsdp --use_compile
Output:
Trying compile. Trying compile. 0%| | 0/200 [00:00<?, ?it/s][rank0]:W1024 16:27:32.344000 1770485 site-packages/torch/_logging/_internal.py:1081] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored [rank1]:W1024 16:27:32.344000 1770486 site-packages/torch/_logging/_internal.py:1081] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:39.578000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 256 [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:39.655000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/1] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 256 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. 
expected 256, actual 512 [rank1]:V1024 16:27:44.482000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/2] [__recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:27:44.483000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/2] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 8 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:27:48.856000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 512 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 8 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank1]:V1024 16:27:48.877000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/3] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [__recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank0]:V1024 16:27:53.546000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank1]:V1024 16:27:53.910000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/4] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/4: tensor 'L['x']' size mismatch at index 2. expected 8, actual 16 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. 
expected 512, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank0]:V1024 16:27:58.484000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 2. expected 8, actual 16 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 1. expected 512, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 1024 [rank1]:V1024 16:27:58.940000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/5] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. 
expected 128, actual 1024 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] triggered by the following guard failure(s): [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/5: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 2. expected 8, actual 32 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 32 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank0]:V1024 16:28:03.509000 1770485 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] Recompiling function forward in /home/user/compile_debug.py:150 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] triggered by the following guard failure(s): [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/5: tensor 'L['x']' size mismatch at index 1. 
expected 1024, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/4: tensor 'L['x']' size mismatch at index 1. expected 1024, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/3: tensor 'L['x']' size mismatch at index 2. expected 8, actual 32 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/2: tensor 'L['x']' size mismatch at index 2. expected 16, actual 32 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [__recompiles] - 1/1: tensor 'L['x']' size mismatch at index 1. expected 256, actual 512 [rank1]:V1024 16:28:04.811000 1770486 site-packages/torch/_dynamo/guards.py:2813] [1/6] [recompiles] - 1/0: tensor 'L['x']' size mismatch at index 1. expected 128, actual 512 Loss: 1.9992671012878418: 0%| | 0/200 [00:38<?, ?it/s]/home/user/anaconda3/envs/torch25/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.) return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass /home/user/anaconda3/envs/torch25/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.) 
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass Loss: 2.003182888031006: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:25<00:00, 2.34it/s]
PyTorch version: 2.5.0+cu124 Is debug build: False CUDA used to build PyTorch: 12.4 ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64) GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 Clang version: 14.0.0-1ubuntu1.1 CMake version: version 3.30.0 Libc version: glibc-2.35
Python version: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] (64-bit runtime) Python platform: Linux-5.4.210-39.1-x86_64-with-glibc2.35 Is CUDA available: True CUDA runtime version: 12.4.99 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA A800-SXM4-80GB GPU 1: NVIDIA A800-SXM4-80GB GPU 2: NVIDIA A800-SXM4-80GB GPU 3: NVIDIA A800-SXM4-80GB GPU 4: NVIDIA A800-SXM4-80GB GPU 5: NVIDIA A800-SXM4-80GB GPU 6: NVIDIA A800-SXM4-80GB GPU 7: NVIDIA A800-SXM4-80GB
Nvidia driver version: 550.54.14 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True
CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 48 bits physical, 48 bits virtual Byte Order: Little Endian CPU(s): 120 On-line CPU(s) list: 0-119 Vendor ID: AuthenticAMD Model name: AMD EPYC 7662 64-Core Processor CPU family: 23 Model: 49 Thread(s) per core: 1 Core(s) per socket: 1 Socket(s): 120 Stepping: 0 BogoMIPS: 3992.45 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd arat npt nrip_save umip rdpid arch_capabilities Virtualization: AMD-V L1d cache: 7.5 MiB (120 instances) L1i cache: 7.5 MiB (120 instances) L2 cache: 60 MiB (120 instances) L3 cache: 1.9 GiB (120 instances) NUMA node(s): 4 NUMA node0 CPU(s): 0-29 NUMA node1 CPU(s): 30-59 NUMA node2 CPU(s): 60-89 NUMA node3 CPU(s): 90-119 Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected
Versions of relevant libraries: [pip3] flake8==5.0.4 [pip3] mypy-extensions==1.0.0 [pip3] numpy==1.26.4 [pip3] open-clip-torch==2.24.0 [pip3] pytorch-warmup==0.1.1 [pip3] torch==2.5.0 [pip3] torch-fidelity==0.3.0 [pip3] torch-model-archiver==0.11.1 [pip3] torch-tb-profiler==0.4.3 [pip3] torch-workflow-archiver==0.2.14 [pip3] torchaudio==2.5.0 [pip3] torchdata==0.7.1 [pip3] torchmetrics==1.4.0.post0 [pip3] torchsde==0.2.6 [pip3] torchserve==0.11.1 [pip3] torchvision==0.20.0 [pip3] triton==3.1.0 [conda] numpy 1.26.4 pypi_0 pypi [conda] open-clip-torch 2.24.0 pypi_0 pypi [conda] torch 2.5.0 pypi_0 pypi [conda] torch-fidelity 0.3.0 pypi_0 pypi [conda] torch-model-archiver 0.11.1 pypi_0 pypi [conda] torch-tb-profiler 0.4.3 pypi_0 pypi [conda] torch-workflow-archiver 0.2.14 pypi_0 pypi [conda] torchaudio 2.5.0 pypi_0 pypi [conda] torchdata 0.7.1 pypi_0 pypi [conda] torchmetrics 1.4.0.post0 pypi_0 pypi [conda] torchsde 0.2.6 pypi_0 pypi [conda] torchserve 0.11.1 pypi_0 pypi [conda] torchvision 0.20.0 pypi_0 pypi [conda] triton 3.1.0 pypi_0 pypi
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec
It seems that after the first FSDP-wrapped block is compiled, the compiler treats each following block as the same function and recompiles it, because the block's input shapes differ.
🐛 Describe the bug
Compiling the plain UNet model works fine, but the FSDP-wrapped UNet is recompiled for every block. In the real setup, the recompilation cache-size limit is reached quickly.
Code:
Command:
Output:
Versions
PyTorch version: 2.5.0+cu124 Is debug build: False CUDA used to build PyTorch: 12.4 ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64) GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 Clang version: 14.0.0-1ubuntu1.1 CMake version: version 3.30.0 Libc version: glibc-2.35
Python version: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] (64-bit runtime) Python platform: Linux-5.4.210-39.1-x86_64-with-glibc2.35 Is CUDA available: True CUDA runtime version: 12.4.99 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA A800-SXM4-80GB GPU 1: NVIDIA A800-SXM4-80GB GPU 2: NVIDIA A800-SXM4-80GB GPU 3: NVIDIA A800-SXM4-80GB GPU 4: NVIDIA A800-SXM4-80GB GPU 5: NVIDIA A800-SXM4-80GB GPU 6: NVIDIA A800-SXM4-80GB GPU 7: NVIDIA A800-SXM4-80GB
Nvidia driver version: 550.54.14 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True
CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 48 bits physical, 48 bits virtual Byte Order: Little Endian CPU(s): 120 On-line CPU(s) list: 0-119 Vendor ID: AuthenticAMD Model name: AMD EPYC 7662 64-Core Processor CPU family: 23 Model: 49 Thread(s) per core: 1 Core(s) per socket: 1 Socket(s): 120 Stepping: 0 BogoMIPS: 3992.45 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd arat npt nrip_save umip rdpid arch_capabilities Virtualization: AMD-V L1d cache: 7.5 MiB (120 instances) L1i cache: 7.5 MiB (120 instances) L2 cache: 60 MiB (120 instances) L3 cache: 1.9 GiB (120 instances) NUMA node(s): 4 NUMA node0 CPU(s): 0-29 NUMA node1 CPU(s): 30-59 NUMA node2 CPU(s): 60-89 NUMA node3 CPU(s): 90-119 Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected
Versions of relevant libraries: [pip3] flake8==5.0.4 [pip3] mypy-extensions==1.0.0 [pip3] numpy==1.26.4 [pip3] open-clip-torch==2.24.0 [pip3] pytorch-warmup==0.1.1 [pip3] torch==2.5.0 [pip3] torch-fidelity==0.3.0 [pip3] torch-model-archiver==0.11.1 [pip3] torch-tb-profiler==0.4.3 [pip3] torch-workflow-archiver==0.2.14 [pip3] torchaudio==2.5.0 [pip3] torchdata==0.7.1 [pip3] torchmetrics==1.4.0.post0 [pip3] torchsde==0.2.6 [pip3] torchserve==0.11.1 [pip3] torchvision==0.20.0 [pip3] triton==3.1.0 [conda] numpy 1.26.4 pypi_0 pypi [conda] open-clip-torch 2.24.0 pypi_0 pypi [conda] torch 2.5.0 pypi_0 pypi [conda] torch-fidelity 0.3.0 pypi_0 pypi [conda] torch-model-archiver 0.11.1 pypi_0 pypi [conda] torch-tb-profiler 0.4.3 pypi_0 pypi [conda] torch-workflow-archiver 0.2.14 pypi_0 pypi [conda] torchaudio 2.5.0 pypi_0 pypi [conda] torchdata 0.7.1 pypi_0 pypi [conda] torchmetrics 1.4.0.post0 pypi_0 pypi [conda] torchsde 0.2.6 pypi_0 pypi [conda] torchserve 0.11.1 pypi_0 pypi [conda] torchvision 0.20.0 pypi_0 pypi [conda] triton 3.1.0 pypi_0 pypi
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec