Cathal already did experiments with RMSNorm (but from TransformerEngine, I think). It might have been hard-coded, but it would be good to coordinate.
CC: @cathalobrien
Hey, yeah, I have this PR: https://github.com/ecmwf/anemoi-models/pull/35. I put it on ice a while back because I thought it would cause problems in inference if we have arbitrary functions in the checkpoint file.
But now that the checkpoints are weights-only, it should be fine. I can refresh it next week.
I see, this is related, but I was thinking of something more general: I would like to be able to write custom normalization layers, e.g.
import logging
from functools import partial
from typing import Callable, Optional

from torch import Tensor, nn
from torch.distributed.distributed_c10d import ProcessGroup

LOGGER = logging.getLogger(__name__)


# BaseBlock and the attention/MLP wiring (elided below) come from anemoi-models.
class TransformerProcessorBlock(BaseBlock):
    """Transformer block with MultiHeadSelfAttention and MLPs."""

    def __init__(
        self,
        num_channels: int,
        hidden_dim: int,
        num_heads: int,
        activation: str,
        window_size: int,
        dropout_p: float = 0.0,
        layer_norm: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        try:
            act_func = getattr(nn, activation)
        except AttributeError as ae:
            LOGGER.error("Activation function %s not supported", activation)
            raise RuntimeError from ae

        # Instantiate the normalization layers from the Hydra-provided factory
        # (e.g. a partial); fall back to plain LayerNorm if none is configured.
        norm_factory = layer_norm if layer_norm is not None else partial(nn.LayerNorm, num_channels)
        self.layer_norm1 = norm_factory()
        self.layer_norm2 = norm_factory()
        ...

    def forward(
        self,
        x: Tensor,
        shapes: list,
        batch_size: int,
        model_comm_group: Optional[ProcessGroup] = None,
        **layer_kwargs,
    ) -> Tensor:
        # Needs to be out-of-place for gradient propagation
        x = x + self.attention(
            self.layer_norm1(x, **layer_kwargs), shapes, batch_size, model_comm_group=model_comm_group
        )
        x = x + self.mlp(self.layer_norm2(x, **layer_kwargs))
        return x
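Just to make the intent concrete, here is a rough sketch of how that signature could be driven, assuming layer_norm arrives as a zero-argument factory (a plain functools.partial here; the sizes are made up, and the attention/MLP wiring elided above is assumed to exist):

from functools import partial

from torch import nn

# Build a factory that already carries its own parameters, so the block can
# simply call it once per normalization site.
norm_factory = partial(nn.LayerNorm, normalized_shape=1024, eps=1e-5)

block = TransformerProcessorBlock(
    num_channels=1024,
    hidden_dim=4096,
    num_heads=16,
    activation="GELU",
    window_size=512,
    layer_norm=norm_factory,
)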
Do you think this could be combined with your PR @cathalobrien?
Ah I see, yeah I think this should work.
I already have this implemented:
LayerNorm:
  # _target_: "torch.nn.LayerNorm"  # the default PyTorch implementation
  _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm"  # my desired layer norm
  _partial_: True
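For reference, a minimal sketch of how Hydra resolves an entry like that, assuming Hydra >= 1.2 (where _partial_ is supported) and liger-kernel installed; the hidden size of 1024 is just a placeholder:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "LayerNorm": {
            "_target_": "liger_kernel.transformers.rms_norm.LigerRMSNorm",
            "_partial_": True,
        }
    }
)

# With _partial_: True, instantiate() returns a functools.partial rather than a
# built module, so the block can create one instance per normalization site.
norm_factory = instantiate(cfg.LayerNorm)
layer_norm1 = norm_factory(1024)  # the hidden size still has to be supplied somewhere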
I haven't tried it with a hand-written layer norm, but I assume that as long as the import in _target_ points to the right place it should be fine.
I like your idea of passing **layer_kwargs directly to the instantiated layer_norm; at the time I was wondering how to handle arbitrary parameters.
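For the hand-written case, a minimal sketch of a custom normalization module that a _target_ could point to (the class name and any module path used in the config are made up):

import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Bare-bones RMSNorm: scale features by their root mean square."""

    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

Pointing the config at it would then just be a matter of setting _target_ to that class's import path (e.g. "my_package.norms.SimpleRMSNorm", a hypothetical path) with _partial_: True.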
On second thought, I believe it should be just **kwargs, in case someone wants to do something else in the forward function in the future.
Yes, e.g. cross attention or some fancy bias terms for the attention could also be passed.
I'm closing this, since PR https://github.com/ecmwf/anemoi-models/pull/35 already covers it.
Is your feature request related to a problem? Please describe.
Currently, the processor is implemented with LayerNorm. I would like to be able to use other normalization layers (https://pytorch.org/docs/stable/nn.html#normalization-layers), including custom normalization layers.
Describe the solution you'd like
I would like to specify the normalization layer of the processor in the config, e.g. transformer.yaml:
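Something along these lines, mirroring the config snippet shown in the comments above (the exact key name and its placement inside transformer.yaml are placeholders):

layer_norm:
  # _target_: "torch.nn.LayerNorm"  # default PyTorch implementation
  _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm"  # or any custom normalization layer
  _partial_: True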
Describe alternatives you've considered
No response
Additional context
No response
Organisation
No response