alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

The delta value is unreasonable during calculation on macOS #58

Open StepNeverStop opened 4 days ago

StepNeverStop commented 4 days ago

When I was running experiments, I ran into a problem with the loss becoming NaN. The same problem has also been reported in mamba-ssm. After debugging step by step, I found that it is mainly caused by an incorrect computation when the operation runs on some devices (for me, macOS 14.4.1 on an M3 Pro).

Softplus should always output a value greater than 0, but in my run below it can actually produce negative values at this line. The error is sporadic; it may be alleviated by adjusting the learning rate, initialization, etc., but the miscomputation on some devices means it cannot be eradicated that way:

# debug prints around the softplus call in the forward pass
delta = delta.transpose(1, 2)
print(
    "delta",
    torch.isnan(delta).any(),
    delta.min().item(),
    delta.max().item(),
    delta.mean().item(),
)
print(
    "self.dt_proj.bias",
    torch.isnan(self.dt_proj.bias).any(),
    self.dt_proj.bias.min().item(),
    self.dt_proj.bias.max().item(),
    self.dt_proj.bias.mean().item(),
)
delta = F.softplus(delta + self.dt_proj.bias)
print(
    "delta",
    torch.isnan(delta).any(),
    delta.min().item(),
    delta.max().item(),
    delta.mean().item(),
)

and the outputs:

delta tensor(False, device='mps:0') -1.2973963022232056 0.5328638553619385 -0.7353726029396057
self.dt_proj.bias tensor(False, device='mps:0') -6.887321472167969 -2.252901554107666 -4.5936760902404785
delta tensor(False, device='mps:0') -8.184718132019043 -1.7200376987457275 -5.3290486335754395
delta tensor(True, device='mps:0') -0.45835205912590027 1.662992238998413 nan
self.dt_proj.bias tensor(False, device='mps:0') -6.9040141105651855 -2.2643179893493652 -4.700575351715088
delta tensor(True, device='mps:0') -0.45835205912590027 1.662992238998413 nan

This may be due to differences in how PyTorch computes the operation on different devices, or it may be a bug in PyTorch itself. There is currently no perfect solution other than manually rewriting the softplus operation. At least I can confirm that the problem does not occur when running the same experiments on Linux with CUDA.
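For reference, here is a minimal standalone check (independent of mamba.py; the tensor shape and value range are made up for illustration) that compares F.softplus on CPU and on MPS for the same input. Whether it reproduces the negative outputs will depend on the PyTorch build, macOS version and hardware:

import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    # mostly negative inputs, roughly in the range of delta + dt_proj.bias above
    x_cpu = torch.randn(1024, dtype=torch.float32) * 2.0 - 5.0
    x_mps = x_cpu.to("mps")

    out_cpu = F.softplus(x_cpu)
    out_mps = F.softplus(x_mps).cpu()

    # softplus is strictly positive, so any negative output points to a bad kernel
    print("min on cpu:", out_cpu.min().item())
    print("min on mps:", out_mps.min().item())
    print("max abs diff:", (out_cpu - out_mps).abs().max().item())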

I wrote a numerically stable softplus function and tried it out:

def stable_softplus(x, threshold=20):
    """
    Implements a numerically stable version of the softplus function.

    Args:
        x (torch.Tensor): Input tensor.
        threshold (float): Value above which softplus(x) is approximated by x.

    Returns:
        torch.Tensor: The result of applying the softplus function.
    """
    # For large x, approximate softplus(x) as x to avoid overflow in exp(x);
    # otherwise use the stable form log1p(exp(-|x|)) + max(x, 0).
    return torch.where(x > threshold, x, torch.log1p(torch.exp(-torch.abs(x))) + torch.maximum(x, torch.tensor(0.0)))

stable_delta = stable_softplus(delta + self.dt_proj.bias)
print(
    "stable_delta",
    torch.isnan(stable_delta).any(),
    stable_delta.min().item(),
    stable_delta.max().item(),
    stable_delta.mean().item(),
)
delta = F.softplus(delta + self.dt_proj.bias)
print(
    "delta",
    torch.isnan(delta).any(),
    delta.min().item(),
    delta.max().item(),
    delta.mean().item(),
)

The result looks more normal:

delta tensor(False, device='mps:0') -1.2973963022232056 0.5328638553619385 -0.7353726029396057
self.dt_proj.bias tensor(False, device='mps:0') -6.887321472167969 -2.252901554107666 -4.5936760902404785
stable_delta tensor(False, device='mps:0') 0.00027879164554178715 0.16471698880195618 0.010891437530517578
delta tensor(False, device='mps:0') -8.184718132019043 -1.7200376987457275 -5.3290486335754395
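As a quick sanity check (not part of the code above, just to confirm the replacement matches the built-in where it behaves correctly), stable_softplus can be compared against F.softplus on CPU:

import torch
import torch.nn.functional as F

def stable_softplus(x, threshold=20):
    # same definition as above
    return torch.where(x > threshold, x, torch.log1p(torch.exp(-torch.abs(x))) + torch.maximum(x, torch.tensor(0.0)))

x = torch.linspace(-30.0, 30.0, steps=2001)
diff = (stable_softplus(x) - F.softplus(x)).abs().max()
print("max abs difference on cpu:", diff.item())  # expected to be on the order of float32 rounding error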
alxndrTL commented 2 days ago

Hello, thank you for this detailed report.

That's very useful to know. I will update the repo with the option of using a manually defined softplus function. On CUDA devices I haven't had any problems training mamba.py models, but as you said this can differ on other devices, and one of the goals of this repo is to allow Mamba training on non-CUDA devices, so it's kind of a big deal.
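One possible way to expose this (just a sketch; the actual option name and wiring in mamba.py may end up different) is a config flag that selects between F.softplus and the manual implementation:

import torch
import torch.nn.functional as F

def stable_softplus(x, threshold=20):
    # numerically stable softplus; zeros_like keeps the constant on the same device as x
    return torch.where(x > threshold, x, torch.log1p(torch.exp(-torch.abs(x))) + torch.maximum(x, torch.zeros_like(x)))

use_stable_softplus = True  # hypothetical flag, e.g. something like MambaConfig(use_stable_softplus=True)
softplus_fn = stable_softplus if use_stable_softplus else F.softplus

# inside the forward pass, the existing call would then become:
# delta = softplus_fn(delta + self.dt_proj.bias)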