Hi,

I noticed an error/warning caused by a stride mismatch in DDP when training the ddpmpp architecture. A simple fix is to make the outputs of the attention operations contiguous. Here's the error:
params[166] in this process with sizes [256, 256, 1, 1] appears not to match strides of the same param in process 0.
To fix it you only need to modify five lines of code.
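For context, "strides" here is about memory layout: two tensors can have the same sizes but be laid out differently, and .contiguous() copies a tensor into the default densely packed layout. A minimal standalone illustration (the shape is just the one from the error message, nothing from the repo):

import torch

x = torch.randn(256, 256, 1, 1)   # default, densely packed layout
y = x.transpose(0, 1)             # same sizes, permuted strides -> not contiguous
print(y.shape, y.stride(), y.is_contiguous())   # strides differ from x, prints False
z = y.contiguous()                # copies the data into the default layout
print(z.shape, z.stride(), z.is_contiguous())   # default strides again, prints True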
For the AttentionOp module, you modify forward and backward from:
class AttentionOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k):
        w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype)
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32)
        dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1])
        dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1])
        return dq, dk
to:
class AttentionOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k):
        w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype).contiguous() # added .contiguous()
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32).contiguous() # added .contiguous()
        dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype).contiguous() / np.sqrt(k.shape[1]) # added .contiguous()
        dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype).contiguous() / np.sqrt(k.shape[1]) # added .contiguous()
        return dq, dk
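A quick sanity check that both passes now return contiguous tensors (just a sketch with made-up shapes, assuming the patched class above and its numpy import are in scope and a recent PyTorch):

import torch

q = torch.randn(4, 64, 256, requires_grad=True)   # (batch * heads, channels per head, positions)
k = torch.randn(4, 64, 256, requires_grad=True)
w = AttentionOp.apply(q, k)
w.sum().backward()
print(w.is_contiguous(), q.grad.is_contiguous(), k.grad.is_contiguous())   # expect True True True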
Similarly, in the UNetBlock module, you modify forward from:
def forward(self, x, emb):
    orig = x
    x = self.conv0(silu(self.norm0(x)))
    params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
    if self.adaptive_scale:
        scale, shift = params.chunk(chunks=2, dim=1)
        x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
    else:
        x = silu(self.norm1(x.add_(params)))
    x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
    x = x.add_(self.skip(orig) if self.skip is not None else orig)
    x = x * self.skip_scale
    if self.num_heads:
        q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
        w = AttentionOp.apply(q, k)
        a = torch.einsum('nqk,nck->ncq', w, v)
        x = self.proj(a.reshape(*x.shape)).add_(x)
        x = x * self.skip_scale
    return x
to:
def forward(self, x, emb):
    orig = x
    x = self.conv0(silu(self.norm0(x)))
    params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
    if self.adaptive_scale:
        scale, shift = params.chunk(chunks=2, dim=1)
        x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
    else:
        x = silu(self.norm1(x.add_(params)))
    x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
    x = x.add_(self.skip(orig) if self.skip is not None else orig)
    x = x * self.skip_scale
    if self.num_heads:
        q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
        w = AttentionOp.apply(q, k)
        a = torch.einsum('nqk,nck->ncq', w, v).contiguous() # added .contiguous()
        x = self.proj(a.reshape(*x.shape)).add_(x)
        x = x * self.skip_scale
    return x
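If you want to double-check a model on your own run, here's a small hypothetical helper (not part of the repo) that lists every parameter or gradient whose layout is not contiguous, e.g. right after a forward/backward on a single rank:

def report_non_contiguous(module):
    # module: any torch.nn.Module, e.g. the ddpmpp network before wrapping it in DDP
    for name, p in module.named_parameters():
        if not p.is_contiguous():
            print('param', name, tuple(p.shape), p.stride())
        if p.grad is not None and not p.grad.is_contiguous():
            print('grad ', name, tuple(p.grad.shape), p.grad.stride())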
I know the author noticed this warning/error and decided to mute it, but to actually solve it you just need those few lines of code.
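For reference, if the muted message is the usual PyTorch DDP one about gradient layout, muting looks roughly like this (a generic snippet, not a quote from the repo):

import warnings

# Hides the DDP warning about gradient strides instead of fixing the layout.
warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides')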