ecmwf / anemoi-training

Apache License 2.0

Add ScaleTensor #96

Closed. HCookie closed this 3 weeks ago.

HCookie commented 1 month ago

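ScaleTensor

This adds ScaleTensor, a class that stores scalars keyed by the tensor dimension they apply to and resolves them into a single tensor that broadcasts against an input of a given rank (see get_scalar and scale in the example below).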

    Examples
    --------
    >>> import torch
    >>> tensor = torch.randn(3, 4, 5)
    >>> scalars = ScaleTensor((0, torch.randn(3)), (1, torch.randn(4)))
    >>> scaled_tensor = scalars.scale(tensor)
    >>> scalars.get_scalar(tensor.ndim).shape
    torch.Size([3, 4, 1])
    >>> scalars.add_scalar(-1, torch.randn(5))
    >>> scalars.get_scalar(tensor.ndim).shape
    torch.Size([3, 4, 5])
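To illustrate the behaviour in the example above, here is a minimal sketch of how per-dimension scalars could be resolved by broadcasting. MiniScaleTensor is a hypothetical name and this is not the anemoi-training implementation. It assumes that every scalar is 1-D and that all scalars are combined by multiplication. The example's use of scale suggests this, but the PR text does not state it.

```python
from __future__ import annotations

import torch


class MiniScaleTensor:
    """Hypothetical, simplified stand-in for ScaleTensor (not the anemoi-training code).

    Stores 1-D scalars keyed by the dimension they apply to and, on request,
    multiplies them into one tensor that broadcasts against an input of a given rank.
    """

    def __init__(self, *specs: tuple[int, torch.Tensor]) -> None:
        self._scalars: list[tuple[int, torch.Tensor]] = []
        for dim, values in specs:
            self.add_scalar(dim, values)

    def add_scalar(self, dim: int, values: torch.Tensor) -> None:
        # Assumption of this sketch: every scalar is a 1-D tensor.
        if values.ndim != 1:
            raise ValueError("this sketch only supports 1-D scalars")
        self._scalars.append((dim, values))

    def get_scalar(self, ndim: int) -> torch.Tensor:
        # Start from a tensor of ones with shape (1, ..., 1) and `ndim` dimensions.
        result = torch.ones((1,) * ndim)
        for dim, values in self._scalars:
            if not -ndim <= dim < ndim:
                raise ValueError(f"dimension {dim} is out of range for ndim={ndim}")
            dim %= ndim  # resolve negative dims, e.g. -1 -> ndim - 1
            shape = [1] * ndim
            shape[dim] = values.numel()
            # Assumption: scalars combine by multiplication; broadcasting grows `result` along `dim`.
            result = result * values.reshape(shape)
        return result

    def scale(self, tensor: torch.Tensor) -> torch.Tensor:
        # Multiply the input by the resolved scalar; the shapes broadcast.
        return tensor * self.get_scalar(tensor.ndim)


tensor = torch.randn(3, 4, 5)
scalars = MiniScaleTensor((0, torch.randn(3)), (1, torch.randn(4)))
print(scalars.get_scalar(tensor.ndim).shape)  # torch.Size([3, 4, 1])
scalars.add_scalar(-1, torch.randn(5))
print(scalars.get_scalar(tensor.ndim).shape)  # torch.Size([3, 4, 5])
```

The sketch resolves negative dimensions in get_scalar rather than in add_scalar. This lets a scalar registered on -1 attach to the last axis of whichever tensor is being scaled. Dimensions with no scalar stay at size 1 and simply broadcast.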