ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] libc++abi crash when using recurrent layer and transformer #1063

Open domschl opened 2 weeks ago

domschl commented 2 weeks ago

Describe the bug

(libc++abi: terminating due to uncaught exception of type std::runtime_error: 
[compile] Too many inputs/outputs fused in the Metal Compiled primitive which 
exhausted the available argument buffers for the kernel. Please file an issue with 
the function that results in this error. The name of the kernel is
'Nf4MultiplyABOf4AddEFPf4AddOGQf4AddPHRf4AddQISf4MultiplyRJTf4AddDSUf4MultiplyCTVf4SquareRWf4MultiplyVLXf4AddKWYf4SqrtXZf4AddYMAAf4DivideUZABf4SubtractNAA_VVVVVVVVVVVVV_f4f4f4f4f4f4f4f4f4f4f4f4f4_11160318154034397263_strided_dynamic')

In:

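import mlx.core as mx
import mlx.nn as nn

# Forward pass of the model (a method of an nn.Module subclass).
def __call__(self, x):
    L = x.shape[1]
    mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
    x = self.embed(x)
    x = x + self.pe(mx.arange(L))        # positional embedding indexed by position
    x = self.transformer(x, mask)
    x = self.context_recurrent(x) + x    # residual connection around the recurrent layer
    x = self.transformer2(x, mask)
    x = self.out_proj(x)
    return x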

If self.context_recurrent is any of nn.RNN, nn.LSTM, or nn.GRU, the crash above occurs.

To Reproduce

Inserting a recurrent layer such as nn.RNN between the two transformer layers (x = self.context_recurrent(x) + x above) causes the crash.

Complete code: https://github.com/domschl/mlx-poet/blob/cedac548256a1bd2a1bb33362cf9d99f22a360c7/mlx_poet_bug.py (requires pip install ml-indie-tools)

Expected behavior

No crash or, if something is actually wrong, a clear error message. I've checked that there is no tensor-shape problem.


angeloskath commented 2 weeks ago

Hi @domschl, thanks for the bug report. The bug comes from compilation. There is a subgraph that is too big (more precisely, it has too many inputs) to fuse into a single kernel, but compile still tries and fails. The big cryptic string is actually a representation of the graph that was to be fused.
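To make the failure mode concrete, here is a toy sketch in plain Python of fusing a topologically ordered chain of element-wise ops while capping the number of buffer arguments per kernel. It is not MLX's compile code: the Op record, kernel_args, split_fusion, and max_args are all hypothetical, and max_args merely stands in for the per-kernel argument-buffer limit; how MLX really counts arguments is not modeled. In this model a kernel needs one argument per input produced outside it plus one per value it must write back for later use. When the next op would exceed the cap, the sketch starts a new kernel (passing the intermediate along as an input) instead of failing.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """One element-wise op in a hypothetical, topologically ordered graph."""
    name: str                 # e.g. "Add", "Multiply", "Sqrt"
    inputs: tuple[str, ...]   # names of the arrays this op reads
    output: str               # name of the array this op produces


def kernel_args(ops, start, end, consumers):
    """Buffer arguments needed to fuse ops[start:end] into one kernel:
    every input produced outside the group, plus every value produced
    inside it that is read after the group (or is a final result)."""
    produced = {op.output for op in ops[start:end]}
    external_inputs = {i for op in ops[start:end] for i in op.inputs if i not in produced}
    outputs = {v for v in produced if not consumers[v] or max(consumers[v]) >= end}
    return len(external_inputs) + len(outputs)


def split_fusion(ops, max_args):
    """Greedily cut the op list into contiguous groups that each fit in
    max_args arguments. Returns each group as a list of output names."""
    consumers = defaultdict(list)   # value name -> indices of ops that read it
    for idx, op in enumerate(ops):
        for i in op.inputs:
            consumers[i].append(idx)

    groups, start = [], 0
    while start < len(ops):
        end = start + 1
        if kernel_args(ops, start, end, consumers) > max_args:
            raise ValueError(f"{ops[start].name} alone needs more than {max_args} arguments")
        # Grow the current kernel while the next op still fits.
        while end < len(ops) and kernel_args(ops, start, end + 1, consumers) <= max_args:
            end += 1
        groups.append([op.output for op in ops[start:end]])
        start = end
    return groups


# Toy chain: t1 = x0 + x1, t2 = t1 + x2, ..., t6 = t5 + x6 (seven separate inputs).
ops = [Op("Add", ("x0" if k == 1 else f"t{k-1}", f"x{k}"), f"t{k}") for k in range(1, 7)]
print(split_fusion(ops, max_args=8))  # -> [['t1', 't2', 't3', 't4', 't5', 't6']]
print(split_fusion(ops, max_args=4))  # -> [['t1', 't2'], ['t3', 't4'], ['t5', 't6']]
```

Cutting greedily keeps every kernel within the cap, at the cost of extra kernel launches and an intermediate written to memory at each cut.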

We'll look into fixing it (i.e., compile should split the subgraph into two smaller ones). In the meantime, you can disable compilation and the code should run fine.

domschl commented 2 weeks ago

Tx! Confirmed: without compilation it works fine.