Closed teowenshen closed 1 week ago
I am getting an LLVM error when training with this Mamba2Block code. Do you know what I am doing wrong?
class Mamba2Block(nn.Module):
    """Mamba-2 style mixer block.

    Pipeline: linear input projection -> left-padded depthwise 1D convolution
    with SiLU -> chunked selective scan (``mamba_chunk_scan_combined``) ->
    gated RMS normalization -> linear output projection.

    The scan is always fed float32 operands. Under AMP the convolution output
    is implicitly half precision, and passing those tensors to the Triton scan
    kernel was reported to abort with
    ``LLVM ERROR: Failed to compute parent layout for slice layout``.
    """

    def __init__(
        self,
        num_heads: int,
        encoder_dim: int,
        mamba_dim: int,
        kernel_size: int,
        state_dim: int,  # E
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
    ):
        """Build the projections, convolution and SSM parameters.

        Args:
            num_heads: number of SSM heads; must divide ``mamba_dim``.
            encoder_dim: feature size of the block's input and output.
            mamba_dim: inner feature size of the ``x`` and ``z`` branches.
            kernel_size: kernel width of the depthwise convolution.
            state_dim: per-group size of the ``B`` and ``C`` projections.
            A_init_range: ``(low, high)`` range for the uniform init of ``A``;
                ``low`` must be positive and ``high >= low``.
            dt_min: lower bound of the log-uniform ``dt`` initialization.
            dt_max: upper bound of the log-uniform ``dt`` initialization.
            dt_init_floor: minimum value the initial ``dt`` is clamped to.
            dt_limit: forwarded to the scan as ``dt_limit`` unless it equals
                the default ``(0.0, inf)``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.mamba_dim = mamba_dim
        self.n_groups = n_groups = 1
        self.dt_limit = dt_limit
        self.kernel_size = kernel_size

        assert not (mamba_dim % num_heads), (mamba_dim % num_heads)

        # [z, x, B, C, dt]
        d_in_proj = 2 * mamba_dim + 2 * n_groups * state_dim + num_heads
        self.in_proj = nn.Linear(encoder_dim, d_in_proj, bias=False)

        # Depthwise conv over the concatenated [x, B, C] channels. padding=0
        # because forward() pads on the left only.
        conv_dim = mamba_dim + 2 * n_groups * state_dim
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=True,
            kernel_size=kernel_size,
            groups=conv_dim,
            padding=0,
        )
        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(num_heads,) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(num_heads, dtype=torch.float32).uniform_(*A_init_range)
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(num_heads))

        # Extra normalization layer right before output projection
        self.norm = RMSNormGated(
            mamba_dim,
            eps=1e-5,
            norm_before_gate=False,
        )
        self.out_proj = nn.Linear(mamba_dim, encoder_dim, bias=False)

    def forward(self, u: Tensor):
        """
        u: (B, L, D)
        Returns: out : same shape as u
        """
        batch, T, _ = u.shape
        mamba_dim = self.mamba_dim
        n_groups = self.n_groups
        num_heads = self.num_heads

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        # Only forward dt_limit when it differs from the default.
        dt_limit_kwargs = (
            {}
            if self.dt_limit == (0.0, float("inf"))
            else dict(dt_limit=self.dt_limit)
        )

        z, xBC, dt = torch.split(
            zxbcdt,
            [mamba_dim, mamba_dim + 2 * n_groups * self.state_dim, num_heads],
            dim=-1,
        )
        dt = nn.functional.softplus(dt + self.dt_bias)  # (B, L, nheads)

        # 1D Convolution
        xBC = xBC.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        # Pad kernel_size - 1 steps on the left only, so the conv output has
        # length T.
        xBC = nn.functional.pad(xBC, pad=(self.kernel_size - 1, 0))
        # (B, L, self.d_inner + 2 * n_groups * d_state)
        xBC = self.act(self.conv1d(xBC).transpose(1, 2))

        # split into 3 main branches: X, B, C
        # These correspond to V, K, Q respectively in the SSM/attention duality
        x, B, C = torch.split(
            xBC,
            [mamba_dim, n_groups * self.state_dim, n_groups * self.state_dim],
            dim=-1,
        )
        x = x.reshape(batch, T, num_heads, -1)
        B = B.reshape(batch, T, n_groups, -1)
        C = C.reshape(batch, T, n_groups, -1)

        # Force every scan operand to float32. Under AMP, x/B/C come out of
        # the conv in half precision, which was reported to crash the Triton
        # kernel at compile time. `.float()` is a no-op for float32 tensors.
        scan_out_dtype = x.dtype
        y: Tensor = mamba_chunk_scan_combined(
            x.float(),
            dt.float(),
            A.float(),
            B.float(),
            C.float(),
            chunk_size=256,
            D=self.D.float(),
            z=None,
            **dt_limit_kwargs,
        )
        # Restore the pre-scan dtype so the gated norm and output projection
        # see the same dtypes they did before the float32 cast was added.
        y = y.to(scan_out_dtype)
        y = y.reshape(batch, T, -1)

        # Multiply "gate" branch and apply extra normalization layer
        y = self.norm(y, z)
        out = self.out_proj(y)
        return out
This is the error.
LLVM ERROR: Failed to compute parent layout for slice layout. Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it): 0 libtriton.so 0x00007f6588955088 1 libtriton.so 0x00007f6584e91880 2 libtriton.so 0x00007f6588952bac 3 libtriton.so 0x00007f658895573d 4 libpthread.so.0 0x00007f665c107420 5 libc.so.6 0x00007f665bdea00b gsignal + 203 6 libc.so.6 0x00007f665bdc9859 abort + 299 7 libtriton.so 0x00007f65888b95ac 8 libtriton.so 0x00007f65888b93d6 9 libtriton.so 0x00007f6584b6dce2 10 libtriton.so 0x00007f6584c50e99 11 libtriton.so 0x00007f6584cccb62 12 libtriton.so 0x00007f6584cccfaf 13 libtriton.so 0x00007f6586913781 14 libtriton.so 0x00007f658695764a 15 libtriton.so 0x00007f658695406f 16 libtriton.so 0x00007f6586914547 17 libtriton.so 0x00007f6586913817 18 libtriton.so 0x00007f65869149a0 19 libtriton.so 0x00007f658691c75b 20 libtriton.so 0x00007f6584d74a3e 21 libtriton.so 0x00007f6585119c06 22 libtriton.so 0x00007f658511a432 23 libtriton.so 0x00007f658511cc5a 24 libtriton.so 0x00007f6584e4c832 25 libtriton.so 0x00007f6584dfc92d 26 python 0x00000000004fc697 27 python 0x00000000004f614b _PyObject_MakeTpCall + 603 28 python 0x000000000050819f 29 python 0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238 30 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 31 python 0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238 32 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 33 python 0x00000000004ed2bf _PyEval_EvalFrameDefault + 799 34 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 35 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 36 python 0x0000000000507eae 37 python 0x0000000000508858 PyObject_Call + 184 38 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 39 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 40 python 0x00000000004ed2bf _PyEval_EvalFrameDefault + 799 41 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 42 python 
0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 43 python 0x0000000000507eae 44 python 0x0000000000508858 PyObject_Call + 184 45 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 46 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 47 python 0x00000000004ed2bf _PyEval_EvalFrameDefault + 799 48 python 0x0000000000507eae 49 python 0x0000000000508858 PyObject_Call + 184 50 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 51 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 52 python 0x0000000000508858 PyObject_Call + 184 53 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 54 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 55 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 56 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 57 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 58 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 59 libtorch_python.so 0x00007f66140d5ca0 THPFunction_apply(_object*, _object*) + 4112 60 python 0x00000000004fc6c0 61 python 0x00000000005089a9 PyObject_Call + 521 62 python 0x00000000004f2a14 _PyEval_EvalFrameDefault + 23156 63 python 0x0000000000507eae 64 python 0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238 65 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 66 python 0x0000000000508858 PyObject_Call + 184 67 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 68 python 0x0000000000508006 69 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 70 python 0x0000000000508006 71 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 72 python 0x00000000004f561d _PyObject_FastCallDictTstate + 205 73 python 0x0000000000506596 _PyObject_Call_Prepend + 102 74 python 0x00000000005cc323 75 python 0x00000000004f614b _PyObject_MakeTpCall + 603 76 python 0x00000000004f2376 _PyEval_EvalFrameDefault + 21462 77 python 0x0000000000507eae 78 python 0x0000000000508858 PyObject_Call + 184 79 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 80 
python 0x0000000000507eae 81 python 0x0000000000508858 PyObject_Call + 184 82 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 83 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 84 python 0x00000000004f56cd _PyObject_FastCallDictTstate + 381 85 python 0x0000000000506596 _PyObject_Call_Prepend + 102 86 python 0x00000000005cc323 87 python 0x00000000004f614b _PyObject_MakeTpCall + 603 88 python 0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359 89 python 0x0000000000507eae 90 python 0x0000000000508858 PyObject_Call + 184 91 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 92 python 0x0000000000507eae 93 python 0x0000000000508858 PyObject_Call + 184 94 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 95 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 96 python 0x00000000004f56cd _PyObject_FastCallDictTstate + 381 97 python 0x0000000000506596 _PyObject_Call_Prepend + 102 98 python 0x00000000005cc323 99 python 0x00000000004f614b _PyObject_MakeTpCall + 603 100 python 0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359 101 python 0x0000000000508006 102 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 103 python 0x0000000000508006 104 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 105 python 0x00000000004f561d _PyObject_FastCallDictTstate + 205 106 python 0x0000000000506596 _PyObject_Call_Prepend + 102 107 python 0x00000000005cc323 108 python 0x00000000004f614b _PyObject_MakeTpCall + 603 109 python 0x00000000004f2376 _PyEval_EvalFrameDefault + 21462 110 python 0x0000000000507eae 111 python 0x00000000004f1ac6 _PyEval_EvalFrameDefault + 19238 112 python 0x0000000000507eae 113 python 0x0000000000508858 PyObject_Call + 184 114 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 115 python 0x0000000000507eae 116 python 0x0000000000508858 PyObject_Call + 184 117 python 0x00000000004efb19 _PyEval_EvalFrameDefault + 11129 118 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 119 python 0x00000000004f56cd 
_PyObject_FastCallDictTstate + 381 120 python 0x0000000000506596 _PyObject_Call_Prepend + 102 121 python 0x00000000005cc323 122 python 0x00000000004f614b _PyObject_MakeTpCall + 603 123 python 0x00000000004f26f7 _PyEval_EvalFrameDefault + 22359 124 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 125 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 126 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 127 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 128 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 129 python 0x00000000004ee353 _PyEval_EvalFrameDefault + 5043 130 python 0x00000000004fcadf _PyFunction_Vectorcall + 111 131 python 0x00000000004ed2bf _PyEval_EvalFrameDefault + 799 132 python 0x0000000000591d92 133 python 0x0000000000591cd7 PyEval_EvalCode + 135 134 python 0x00000000005c2967 135 python 0x00000000005bdad0 136 python 0x000000000045956b 137 python 0x00000000005b805f _PyRun_SimpleFileObject + 415 138 python 0x00000000005b7dc3 _PyRun_AnyFileObject + 67 139 python 0x00000000005b4b7d Py_RunMain + 909 140 python 0x0000000000584e49 Py_BytesMain + 57 141 libc.so.6 0x00007f665bdcb083 __libc_start_main + 243 142 python 0x0000000000584cfe xargs: python: terminated by signal 6
My environment is as below. pytorch = 2.10 triton-nightly = 3.0.0.post20240716052845
I managed to fix (or work around) this issue by forcing all inputs to `mamba_chunk_scan_combined` to be `torch.float32`. I was using AMP in my training, so some outputs were implicitly cast to half precision.
mamba_chunk_scan_combined
torch.float32
I am facing LLVM Error when training using this Mamba2Block code. Do you know what I am doing wrong?
This is the error.
My environment is as below. pytorch = 2.10 triton-nightly = 3.0.0.post20240716052845