ecmwf / anemoi-models

Apache License 2.0

Specifying normalization layers. #87

Closed. jakob-schloer closed this issue 12 hours ago.

jakob-schloer commented 13 hours ago

Is your feature request related to a problem? Please describe.

Currently, the processor always uses layer normalization (nn.LayerNorm). I would like to be able to use other normalization layers (https://pytorch.org/docs/stable/nn.html#normalization-layers), including custom ones.

Describe the solution you'd like

I would like to specify the processor's normalization layer in the config, e.g. in transformer.yaml:

layer_norm:  # This needs to be a partial instantiation since it is used in multiple places
  _target_: torch.nn.LayerNorm 
  _partial_: True
  normalized_shape: ${model.num_channels}

processor:
  _target_: anemoi.models.processor.TransformerProcessor
  _convert_: all
  activation: ${model.activation}
  num_layers: 16
  num_chunks: 2
  mlp_hidden_ratio: 4 # GraphTransformer or Transformer only
  num_heads: 16 # GraphTransformer or Transformer only
  window_size: 512
  dropout_p: 0.0 # GraphTransformer
  layer_norm: ${model.layer_norm} # (Optional) Default nn.LayerNorm
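To make the `_partial_: True` comment concrete: Hydra resolves `_target_` to an importable callable and, with `_partial_`, returns a `functools.partial` instead of calling it, so every later call builds a fresh module with its own parameters. The sketch below is a simplified stand-in written for illustration (`locate` and `instantiate_node` are hypothetical helpers, not Hydra's implementation; they ignore interpolation, recursion and `_convert_`).

```python
import importlib
from functools import partial
from typing import Any


def locate(path: str) -> Any:
    """Import 'package.module.attr' and return attr, e.g. 'torch.nn.LayerNorm'."""
    module_path, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


def instantiate_node(node: dict) -> Any:
    """Simplified stand-in for hydra.utils.instantiate: only _target_ and _partial_."""
    # Drop Hydra's reserved underscore keys; everything else is a constructor kwarg.
    kwargs = {k: v for k, v in node.items() if not k.startswith("_")}
    target = locate(node["_target_"])
    if node.get("_partial_", False):
        # Defer construction: each call of the returned partial builds a new object.
        return partial(target, **kwargs)
    return target(**kwargs)


# Usage (requires torch):
# factory = instantiate_node({"_target_": "torch.nn.LayerNorm", "_partial_": True,
#                             "normalized_shape": 256})
# norm1, norm2 = factory(), factory()  # two independent LayerNorm modules
```

This is why the processor can call `layer_norm()` twice and get two separate normalization layers from a single config entry.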


clessig commented 13 hours ago

Cathal has already experimented with RMSNorm (from TransformerEngine, I think). It may have been hard-coded, but it would be good to coordinate.

CC: @cathalobrien
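Since RMSNorm comes up here, for reference these are the standard definitions over the last (channel) dimension of size $d$, with learnable gain $\gamma$, bias $\beta$ and a small $\epsilon$ (individual implementations differ slightly, e.g. in where $\epsilon$ enters):

```latex
\mathrm{LayerNorm}(x)_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i,
\qquad \mu = \frac{1}{d}\sum_{j=1}^{d} x_j,
\qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2

\mathrm{RMSNorm}(x)_i = \gamma_i \, \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}
```

RMSNorm drops the mean subtraction and the bias, which makes it slightly cheaper while keeping the same call shape, so it can be swapped in through the same config entry.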

cathalobrien commented 13 hours ago

Hey, yeah, I have this PR: https://github.com/ecmwf/anemoi-models/pull/35. I put it on ice a while back because I thought having arbitrary functions in the checkpoint file would cause problems at inference time.

But now that checkpoints contain only the weights, it should be fine. I can refresh it next week.
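A generic PyTorch illustration of why a weights-only checkpoint sidesteps the problem (not anemoi's actual checkpointing code): pickling a whole model records classes and functions by import path, so every custom norm must be importable at load time (and lambdas cannot be pickled at all), whereas a `state_dict` is just named tensors and the architecture, including the norm choice, is rebuilt from config. `build_model` is a hypothetical stand-in for that config-driven construction.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in: in practice the architecture (including the
    # norm layer) is rebuilt from the config, not read from the checkpoint.
    return nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))


model = build_model()

# Save only the state_dict: a mapping from parameter names to tensors.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and plain containers,
# so no classes or functions are imported from the file.
state = torch.load(buffer, weights_only=True)

restored = build_model()
restored.load_state_dict(state)
```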

jakob-schloer commented 13 hours ago

I see, this is related, but I was thinking of something more general. I would like to be able to write custom normalization layers, e.g.:

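import logging
from functools import partial
from typing import Callable, Optional

from torch import Tensor, nn
from torch.distributed.distributed_c10d import ProcessGroup

LOGGER = logging.getLogger(__name__)


# BaseBlock, self.attention and self.mlp come from anemoi-models (not shown).
class TransformerProcessorBlock(BaseBlock):
    """Transformer block with MultiHeadSelfAttention and MLPs."""

    def __init__(
        self,
        num_channels: int,
        hidden_dim: int,
        num_heads: int,
        activation: str,
        window_size: int,
        dropout_p: float = 0.0,
        layer_norm: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        try:
            act_func = getattr(nn, activation)
        except AttributeError as ae:
            LOGGER.error("Activation function %s not supported", activation)
            raise RuntimeError from ae

        # `layer_norm` is the partial produced by Hydra (`_partial_: True`);
        # fall back to the current behaviour (nn.LayerNorm) if none is given.
        if layer_norm is None:
            layer_norm = partial(nn.LayerNorm, num_channels)

        # Each call builds a separate module with its own parameters.
        self.layer_norm1 = layer_norm()
        self.layer_norm2 = layer_norm()
        ...

    def forward(
        self,
        x: Tensor,
        shapes: list,
        batch_size: int,
        model_comm_group: Optional[ProcessGroup] = None,
        **layer_kwargs,
    ) -> Tensor:
        # Need to be out of place for gradient propagation.
        # Extra kwargs are forwarded to the norm layers, so the configured norm
        # must accept them (nn.LayerNorm.forward only takes the input tensor).
        x = x + self.attention(
            self.layer_norm1(x, **layer_kwargs), shapes, batch_size, model_comm_group=model_comm_group
        )
        x = x + self.mlp(self.layer_norm2(x, **layer_kwargs))
        return x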

Do you think this could be combined with your PR @cathalobrien?
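Forwarding `**layer_kwargs` to the norms only pays off if a norm actually consumes them. A minimal sketch of such a layer, assuming a hypothetical adaptive layer norm (`ConditionalLayerNorm`, with a hypothetical `cond` keyword carrying a conditioning tensor such as a time embedding; none of this exists in anemoi-models):

```python
from typing import Optional

import torch
from torch import Tensor, nn


class ConditionalLayerNorm(nn.Module):
    """Hypothetical adaptive LayerNorm: scale and shift predicted from `cond`."""

    def __init__(self, normalized_shape: int, cond_dim: int, eps: float = 1e-5):
        super().__init__()
        # No learnable affine here; the affine terms come from `cond`.
        self.norm = nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * normalized_shape)
        # Zero-init so the layer starts out as a plain (non-affine) LayerNorm.
        nn.init.zeros_(self.to_scale_shift.weight)
        nn.init.zeros_(self.to_scale_shift.bias)

    def forward(self, x: Tensor, cond: Optional[Tensor] = None, **kwargs) -> Tensor:
        # Unrelated kwargs are accepted and ignored.
        x = self.norm(x)
        if cond is None:
            return x
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return x * (1 + scale) + shift
```

It could be configured like the proposed `layer_norm` entry, with `_target_` pointing at wherever the class lives (path hypothetical), `_partial_: True`, and `normalized_shape`/`cond_dim` set; calling the block with `cond=...` would then reach both norms.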

cathalobrien commented 13 hours ago

Ah I see, yeah I think this should work.

I already have this implemented:

    LayerNorm:
      # _target_: "torch.nn.LayerNorm" # the default PyTorch implementation
      _target_: "liger_kernel.transformers.rms_norm.LigerRMSNorm" # my desired layer norm
      _partial_: True

I haven't tried a handwritten layer norm, but I assume that as long as `_target_` points to the right import path, it should be fine.

I like your idea of passing **layer_kwargs directly to the instantiated layer_norm; I was wondering how to handle arbitrary parameters at the time.
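As a sketch of what a handwritten layer norm could look like: a minimal RMSNorm module (hypothetical, e.g. placed in your own package and referenced via `_target_`; not the TransformerEngine or Liger implementation). Its constructor takes the channel count as the first argument, like nn.LayerNorm, so the partial-instantiation config works unchanged. Recent PyTorch releases (2.4+) also ship torch.nn.RMSNorm.

```python
import torch
from torch import Tensor, nn


class RMSNorm(nn.Module):
    """Minimal RMSNorm: gain * x / sqrt(mean(x**2) + eps) over the last dim."""

    def __init__(self, normalized_shape: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        # Extra kwargs are accepted and ignored so the block can forward them.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
```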

jakob-schloer commented 12 hours ago

> I like your idea of passing **layer_kwargs directly to the instantiated layer_norm, i was wondering how to handle arbitrary parameters at the time.

On second thought, I believe it should just be **kwargs rather than **layer_kwargs, since in the future someone may want to do something else in the forward function.

clessig commented 12 hours ago

Yes, e.g. cross-attention inputs or some fancy bias terms for the attention could also be passed.
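One practical consequence of passing **kwargs to every norm: built-in layers such as nn.LayerNorm only accept the input tensor in `forward`, so they raise a TypeError as soon as any extra keyword is supplied. A possible workaround, sketched here as a hypothetical adapter (`IgnoreKwargs` is not part of anemoi-models or PyTorch), is a thin wrapper that drops the kwargs:

```python
import torch
from torch import Tensor, nn


class IgnoreKwargs(nn.Module):
    """Hypothetical adapter for norms whose forward() takes only the input."""

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        # Discard kwargs meant for other (e.g. conditional) layers.
        return self.norm(x)
```

The trade-off is that parameter names gain a `norm.` prefix (e.g. `layer_norm1.norm.weight`), which matters when loading checkpoints saved without the wrapper; the alternative is to require every configured norm to accept **kwargs itself.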

jakob-schloer commented 12 hours ago

Closing this, since PR https://github.com/ecmwf/anemoi-models/pull/35 already covers it.