state-spaces / mamba

Mamba SSM architecture
Apache License 2.0
12.65k stars 1.06k forks source link

LLVM Error when training mamba2 #566

Closed teowenshen closed 1 week ago

teowenshen commented 1 week ago

I am facing LLVM Error when training using this Mamba2Block code. Do you know what I am doing wrong?

class Mamba2Block(nn.Module):
    def __init__(
        self,
        num_heads : int,
        encoder_dim : int,
        mamba_dim : int,
        kernel_size : int,
        state_dim : int, # E
        A_init_range = (1, 16),
        dt_min = 0.001,
        dt_max = 0.1,
        dt_init_floor = 1e-4,
        dt_limit = (0.0, float("inf")),
    ):
        super().__init__()
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.mamba_dim = mamba_dim 
        self.n_groups = n_groups = 1
        self.dt_limit = dt_limit
        self.kernel_size = kernel_size

        assert not (mamba_dim % num_heads), (mamba_dim % num_heads)

        # [z, x, B, C, dt]
        d_in_proj = 2 * mamba_dim + 2 * n_groups * state_dim + num_heads
        self.in_proj = nn.Linear(encoder_dim, d_in_proj, bias=False)

        conv_dim = mamba_dim + 2 * n_groups * state_dim
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=True,
            kernel_size=kernel_size,
            groups=conv_dim,
            padding=0,
        )

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(num_heads,) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(num_heads, dtype=torch.float32).uniform_(*A_init_range)
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(num_heads))

        # Extra normalization layer right before output projection
        self.norm = RMSNormGated(
            mamba_dim,
            eps=1e-5,
            norm_before_gate=False,
        )

        self.out_proj = nn.Linear(mamba_dim, encoder_dim, bias=False)

    def forward(self, u : Tensor):
        """
        u: (B, L, D)
        Returns: out : same shape as u
        """

        batch, T, _ = u.shape
        mamba_dim = self.mamba_dim
        n_groups = self.n_groups
        num_heads = self.num_heads

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

        z, xBC, dt = torch.split(
            zxbcdt, [mamba_dim, mamba_dim + 2 * n_groups * self.state_dim, num_heads], dim=-1
        )
        dt = nn.functional.softplus(dt + self.dt_bias)  # (B, L, nheads)

        # 1D Convolution
        xBC = xBC.transpose(1,2) # (B, T, C) -> (B, C, T)
        xBC = nn.functional.pad(xBC, pad=(self.kernel_size-1,0))
        xBC = self.act(self.conv1d(xBC).transpose(1, 2)) # (B, L, self.d_inner + 2 * n_groups * d_state)

        # split into 3 main branches: X, B, C
        # These correspond to V, K, Q respectively in the SSM/attention duality
        x, B, C = torch.split(xBC, [mamba_dim, n_groups * self.state_dim, n_groups * self.state_dim], dim=-1)
        x = x.reshape(batch, T, num_heads, -1)
        B = B.reshape(batch, T, n_groups, -1)
        C = C.reshape(batch, T, n_groups, -1)

        y : Tensor = mamba_chunk_scan_combined(
            x,
            dt,
            A,
            B,
            C,
            chunk_size=256,
            D=self.D,
            z=None,
            **dt_limit_kwargs,
        )
        y = y.reshape(batch, T, -1)

        # Multiply "gate" branch and apply extra normalization layer
        y = self.norm(y, z)
        out = self.out_proj(y)
        return out

This is the error.

LLVM ERROR: Failed to compute parent layout for slice layout.
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libtriton.so       0x00007f6588955088
1  libtriton.so       0x00007f6584e91880
2  libtriton.so       0x00007f6588952bac
3  libtriton.so       0x00007f658895573d
4  libpthread.so.0    0x00007f665c107420
5  libc.so.6          0x00007f665bdea00b gsignal + 203
6  libc.so.6          0x00007f665bdc9859 abort + 299
7  libtriton.so       0x00007f65888b95ac
8  libtriton.so       0x00007f65888b93d6
9  libtriton.so       0x00007f6584b6dce2
10 libtriton.so       0x00007f6584c50e99
11 libtriton.so       0x00007f6584cccb62
12 libtriton.so       0x00007f6584cccfaf
13 libtriton.so       0x00007f6586913781
14 libtriton.so       0x00007f658695764a
15 libtriton.so       0x00007f658695406f
16 libtriton.so       0x00007f6586914547
17 libtriton.so       0x00007f6586913817
18 libtriton.so       0x00007f65869149a0
19 libtriton.so       0x00007f658691c75b
20 libtriton.so       0x00007f6584d74a3e
21 libtriton.so       0x00007f6585119c06
22 libtriton.so       0x00007f658511a432
23 libtriton.so       0x00007f658511cc5a
24 libtriton.so       0x00007f6584e4c832
25 libtriton.so       0x00007f6584dfc92d
26 python             0x00000000004fc697
27 python             0x00000000004f614b _PyObject_MakeTpCall + 603
28 python             0x000000000050819f
29 python             0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238
30 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
31 python             0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238
32 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
33 python             0x00000000004ed2bf _PyEval_EvalFrameDefault + 799
34 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
35 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
36 python             0x0000000000507eae
37 python             0x0000000000508858 PyObject_Call + 184
38 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
39 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
40 python             0x00000000004ed2bf _PyEval_EvalFrameDefault + 799
41 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
42 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
43 python             0x0000000000507eae
44 python             0x0000000000508858 PyObject_Call + 184
45 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
46 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
47 python             0x00000000004ed2bf _PyEval_EvalFrameDefault + 799
48 python             0x0000000000507eae
49 python             0x0000000000508858 PyObject_Call + 184
50 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
51 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
52 python             0x0000000000508858 PyObject_Call + 184
53 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
54 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
55 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
56 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
57 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
58 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
59 libtorch_python.so 0x00007f66140d5ca0 THPFunction_apply(_object*, _object*) + 4112
60 python             0x00000000004fc6c0
61 python             0x00000000005089a9 PyObject_Call + 521
62 python             0x00000000004f2a14 _PyEval_EvalFrameDefault + 23156
63 python             0x0000000000507eae
64 python             0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238
65 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
66 python             0x0000000000508858 PyObject_Call + 184
67 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
68 python             0x0000000000508006
69 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
70 python             0x0000000000508006
71 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
72 python             0x00000000004f561d _PyObject_FastCallDictTstate + 205
73 python             0x0000000000506596 _PyObject_Call_Prepend + 102
74 python             0x00000000005cc323
75 python             0x00000000004f614b _PyObject_MakeTpCall + 603
76 python             0x00000000004f2376 _PyEval_EvalFrameDefault + 21462
77 python             0x0000000000507eae
78 python             0x0000000000508858 PyObject_Call + 184
79 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
80 python             0x0000000000507eae
81 python             0x0000000000508858 PyObject_Call + 184
82 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
83 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
84 python             0x00000000004f56cd _PyObject_FastCallDictTstate + 381
85 python             0x0000000000506596 _PyObject_Call_Prepend + 102
86 python             0x00000000005cc323
87 python             0x00000000004f614b _PyObject_MakeTpCall + 603
88 python             0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359
89 python             0x0000000000507eae
90 python             0x0000000000508858 PyObject_Call + 184
91 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
92 python             0x0000000000507eae
93 python             0x0000000000508858 PyObject_Call + 184
94 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
95 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
96 python             0x00000000004f56cd _PyObject_FastCallDictTstate + 381
97 python             0x0000000000506596 _PyObject_Call_Prepend + 102
98 python             0x00000000005cc323
99 python             0x00000000004f614b _PyObject_MakeTpCall + 603
100 python             0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359
101 python             0x0000000000508006
102 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
103 python             0x0000000000508006
104 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
105 python             0x00000000004f561d _PyObject_FastCallDictTstate + 205
106 python             0x0000000000506596 _PyObject_Call_Prepend + 102
107 python             0x00000000005cc323
108 python             0x00000000004f614b _PyObject_MakeTpCall + 603
109 python             0x00000000004f2376 _PyEval_EvalFrameDefault + 21462
110 python             0x0000000000507eae
111 python             0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238
112 python             0x0000000000507eae
113 python             0x0000000000508858 PyObject_Call + 184
114 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
115 python             0x0000000000507eae
116 python             0x0000000000508858 PyObject_Call + 184
117 python             0x00000000004efb19 _PyEval_EvalFrameDefault + 11129
118 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
119 python             0x00000000004f56cd _PyObject_FastCallDictTstate + 381
120 python             0x0000000000506596 _PyObject_Call_Prepend + 102
121 python             0x00000000005cc323
122 python             0x00000000004f614b _PyObject_MakeTpCall + 603
123 python             0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359
124 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
125 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
126 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
127 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
128 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
129 python             0x00000000004ee353 _PyEval_EvalFrameDefault + 5043
130 python             0x00000000004fcadf _PyFunction_Vectorcall + 111
131 python             0x00000000004ed2bf _PyEval_EvalFrameDefault + 799
132 python             0x0000000000591d92
133 python             0x0000000000591cd7 PyEval_EvalCode + 135
134 python             0x00000000005c2967
135 python             0x00000000005bdad0
136 python             0x000000000045956b
137 python             0x00000000005b805f _PyRun_SimpleFileObject + 415
138 python             0x00000000005b7dc3 _PyRun_AnyFileObject + 67
139 python             0x00000000005b4b7d Py_RunMain + 909
140 python             0x0000000000584e49 Py_BytesMain + 57
141 libc.so.6          0x00007f665bdcb083 __libc_start_main + 243
142 python             0x0000000000584cfe
xargs: python: terminated by signal 6

My environment is as below. pytorch = 2.10 triton-nightly = 3.0.0.post20240716052845

teowenshen commented 1 week ago

I managed to fix (or work around) this issue by forcing all inputs to mamba_chunk_scan_combined to be torch.float32. I was using AMP in my training so some outputs were implicitly cast to half-precision.