fordflip opened 2 months ago
Hi, could you give a little bit more info on the use case? LayerNorm already has scale and offset in `weight` and `bias` - why do you need an additional set of those parameters?
For things like DiT https://arxiv.org/abs/2212.09748: adaptive LayerNorm (adaptive to a condition vector of some sort).
Hm, interesting - I looked at the HuggingFace implementation here: https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/models/normalization.py#L28. It basically computes the weight and bias of LayerNorm rather than keeping them as parameters.
A quick and dirty implementation would basically take the LayerNormMLP implementation from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py, add the code computing weight and bias (scale and shift in the HF implementation) in the `forward` function here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1496, and pass them instead of `self.layer_norm_weight` and `self.layer_norm_bias` here (use `scale`, not `1 + scale`, see note below): https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1551
A random note on numerical precision: the code from HF multiplies the LN output by `1 + scale`. This is not great if you use e.g. BF16, since that addition would then be performed in that precision, which loses some precision (BF16, like all floating-point formats, preserves the most precision around 0, not 1). That is why we introduced the option `zero_centered_gamma` in our LN implementation, which takes the weight and adds 1 to it inside the LayerNorm kernel in FP32 precision. That is why I would enable that option and pass just the `scale`, not `1 + scale`.
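A quick standalone illustration of that effect (plain PyTorch, not TE code; the numbers are just for demonstration):

```python
import torch

# With a small BF16 scale, computing 1 + scale in BF16 rounds the scale away
# entirely (the ULP of 1.0 in BF16 is ~0.0039), while adding 1 in FP32 keeps it.
scale = torch.full((4,), 1e-3, dtype=torch.bfloat16)
x = torch.randn(4)

out_bf16_add = x * (1 + scale).float()   # 1 + scale rounded in BF16 -> exactly 1.0
out_fp32_add = x * (1 + scale.float())   # 1 added in FP32 -> 1.001
print((out_bf16_add - out_fp32_add).abs().max())  # non-zero: the BF16 addition lost the scale
```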
Would also be very interested in both an AdaLayerNorm and an AdaLayerNormMLP, or alternatively a fused MLP without the norm included, as requested in https://github.com/NVIDIA/TransformerEngine/issues/817
Would this be a correct implementation? I don't know if the `_LayerNorm` forward can take a batched weight/bias though, so I just added a for loop (which is probably pretty slow).
```python
import torch

# Import paths below are assumed from TE's current module layout.
from transformer_engine.pytorch import LayerNorm, Linear
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.layernorm import _LayerNorm


class AdaLayerNorm(LayerNorm):
    def __init__(self, hidden_size, cond_hidden_size, scale=True, shift=True, bias=True, **kwargs):
        assert scale or shift
        kwargs["zero_centered_gamma"] = True
        super().__init__(hidden_size, **kwargs)
        # Replace the learnable LN parameters with zero buffers; the actual
        # scale/shift come from the conditioning projection below.
        weight = torch.zeros_like(self.weight)
        ln_bias = torch.zeros_like(self.bias)
        del self.weight, self.bias
        self.register_buffer("weight", weight)
        self.register_buffer("bias", ln_bias)
        if scale and shift:
            self.c_proj = Linear(cond_hidden_size, hidden_size * 2, bias=bias)
        else:
            self.c_proj = Linear(cond_hidden_size, hidden_size, bias=bias)
        self.scale = scale
        self.shift = shift

    @no_torch_dynamo()
    def forward(self, x, cond):
        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, x)
        assert cond.dim() <= 2
        if cond.dim() == 1:
            cond = cond.unsqueeze(0).expand(x.shape[0], cond.shape[-1])
        embs = self.c_proj(cond)
        if torch.is_grad_enabled():
            fwd_fn = _LayerNorm.apply
        else:
            fwd_fn = _LayerNorm.forward
        out = []
        # _LayerNorm only takes a single weight/bias, so loop over the batch.
        for inp, emb in zip(x, embs):
            if self.scale and self.shift:
                scale, shift = emb.chunk(2, dim=-1)
            elif self.scale:
                scale, shift = emb, self.bias
            elif self.shift:
                scale, shift = self.weight, emb
            if torch.is_grad_enabled():
                args = []
            else:
                args = [None]
            args += (
                inp,
                scale,
                shift,
                self.eps,
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.inf_ln_sm_margin,
                self.zero_centered_gamma,
                torch.is_grad_enabled(),
                self.activation_dtype,
            )
            out.append(fwd_fn(*args))
        return torch.stack(out)
```
Yeah, considering that `_LayerNorm` doesn't seem to be able to take a batched weight/bias for the conditioning, this probably needs to be implemented a different way, since the for loop is slow.
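A vectorized, non-fused fallback in plain PyTorch (my own sketch, not using TE kernels; `c_proj` here stands for the conditioning projection from the snippet above) could look roughly like this:

```python
import torch
import torch.nn.functional as F

def ada_layer_norm_ref(x, cond, c_proj, eps=1e-5):
    """Plain-PyTorch reference for batched scale/shift (no TE fused kernel).

    x:      (B, ..., D) activations
    cond:   (B, C) conditioning vectors
    c_proj: a linear layer mapping C -> 2 * D (per-sample scale and shift)
    """
    scale, shift = c_proj(cond).chunk(2, dim=-1)           # (B, D) each
    # Broadcast over any middle dims (e.g. sequence length).
    view = (x.shape[0],) + (1,) * (x.dim() - 2) + (-1,)
    scale, shift = scale.view(view), shift.view(view)
    y = F.layer_norm(x, (x.shape[-1],), eps=eps)            # LN without its own affine params
    # Zero-centered scale: multiply by (1 + scale), doing the +1 in FP32 per the note above.
    return y * (1 + scale.float()).to(y.dtype) + shift
```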
About that numerical precision note: I'm trying to make a compiled version of it, and I'm confused as to whether only the addition is performed in FP32 or the multiplication too?
```python
import torch
import torch.nn.functional as F

@torch.compile
def ada_rms_norm(x: torch.Tensor, n_weight: torch.Tensor):
    B, D = x.shape[0], x.shape[-1]
    scale = D ** 0.5
    # Broadcast the per-sample weight over any middle dims (e.g. sequence length).
    n_weight = n_weight.view(B, *((1,) * (len(x.shape) - 2)), -1)
    return F.normalize(x, dim=-1) * (1 + n_weight.to(dtype=torch.float32)).to(dtype=x.dtype) * scale
```
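In the snippet above, only the `1 +` addition happens in FP32; the result is cast back to `x.dtype` before the multiply with the normalized activations. A variant (my own sketch, just for comparison) that keeps the multiply in FP32 as well and casts once at the end:

```python
@torch.compile
def ada_rms_norm_fp32(x: torch.Tensor, n_weight: torch.Tensor):
    B = x.shape[0]
    n_weight = n_weight.view(B, *((1,) * (len(x.shape) - 2)), -1)
    # Normalization, the +1, and the multiply all in FP32; single cast at the end.
    out = F.normalize(x.float(), dim=-1) * (1 + n_weight.float()) * x.shape[-1] ** 0.5
    return out.to(dtype=x.dtype)
```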
It'd be amazing to have support for a PyTorch LayerNormMLP implementation that accepts a scale and offset tensor to be applied after the LayerNorm but before the MLP. Would be curious to hear what it would take to implement this! Happy to help.
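For what it's worth, here is the requested behaviour un-fused (this is not an existing TE API, just a plain-PyTorch sketch to pin down the semantics):

```python
import torch
import torch.nn.functional as F

def layernorm_mlp_with_modulation(x, scale, shift, fc1_w, fc1_b, fc2_w, fc2_b, eps=1e-5):
    """Un-fused sketch of a LayerNormMLP that takes external scale/offset tensors."""
    y = F.layer_norm(x, (x.shape[-1],), eps=eps)  # LN without its own affine parameters
    y = y * (1 + scale) + shift                   # condition-derived scale/offset
    y = F.gelu(F.linear(y, fc1_w, fc1_b))         # fc1 + activation
    return F.linear(y, fc2_w, fc2_b)              # fc2
```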