ml-explore / mlx-swift-examples

Examples using MLX Swift
MIT License

LLM evaluator setting any repetitionPenalty crashes the program #71

Closed: shawiz closed this issue 6 months ago

shawiz commented 6 months ago

I'm adding a repetitionPenalty to the GenerateParameters constructor. Regardless of what value I set (I tried 0.5, 1, 1.2, and 1.5), the program crashes immediately when the evaluator runs. I was testing various Qwen1.5 models. The error message I got is:

-[MTLDebugComputeCommandEncoder dispatchThreads:threadsPerThreadgroup:]:1441: failed assertion '(threadsPerGrid.width(0) * threadsPerGrid.y(1) * threadsPerGrid.depth(0))(0) must not be 0.'

davidkoski commented 6 months ago

I tried this on mlx-community/Mistral-7B-v0.1-hf-4bit-mlx without error. The Qwen1.5 models are not loading for me because of #53 -- I will try it again once that is resolved.

davidkoski commented 6 months ago

OK, I can reproduce it. I won't check the fix in yet because of the quantization issues in #53; it will be included in that change.

If you want to try it locally, here is the change:

diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 6c85558..89212a7 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]

         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
             }
         } else {
             self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
         if parameters.repetitionContextSize > 1 {
-            repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }

I switched it to use full array indexing and made it match the Python code. I don't know whether this is a bug in the MLX core code or in the calling code. The calling code certainly needed some changes, and I don't think the original Swift version was logically equivalent to the Python version.
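For reference, here is a minimal, framework-free sketch of what the patched logic computes, using plain Swift arrays instead of MLXArray. It assumes a single row of logits and a non-negative context size, and the names initialContext, updateContext, and applyRepetitionPenalty(logits:context:penalty:) are hypothetical illustrations, not the library's API. The context keeps only the last N tokens (like Python's `tokens[-n:]`), and the penalty multiplies negative logits and divides positive ones, as in the `MLX.where` expression in the diff.

```swift
// Hypothetical sketch: plain Swift arrays stand in for MLXArray.

/// Builds the initial repetition context from the prompt, keeping only the
/// last `size` tokens (the equivalent of Python's `prompt[-size:]`).
/// Assumes `size >= 0`; `suffix` traps on a negative length.
func initialContext(prompt: [Int], size: Int) -> [Int] {
    Array(prompt.suffix(size))
}

/// Appends the newly sampled token, then trims back to the last `size` tokens.
func updateContext(_ context: [Int], with token: Int, size: Int) -> [Int] {
    Array((context + [token]).suffix(size))
}

/// Penalizes the logits of tokens in the context: negative logits are
/// multiplied by `penalty`, positive logits are divided by it.
/// Each distinct token is penalized once, even if it repeats in the context;
/// an empty context leaves the logits unchanged.
func applyRepetitionPenalty(logits: [Float], context: [Int], penalty: Float) -> [Float] {
    var result = logits
    for token in Set(context) where result.indices.contains(token) {
        let value = result[token]
        result[token] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

// Usage: a 4-token vocabulary and a context window of 3 tokens.
let logits: [Float] = [2.0, -1.0, 0.5, 3.0]
var context = initialContext(prompt: [0, 1, 1, 2, 3], size: 3)  // [1, 2, 3]
let penalized = applyRepetitionPenalty(logits: logits, context: context, penalty: 2.0)
// penalized == [2.0, -2.0, 0.25, 1.5]
context = updateContext(context, with: 0, size: 3)  // [2, 3, 0]
```

Trimming with `suffix` after every append keeps the window bounded without a separate "is it too long?" branch, which is the same effect the diff gets by switching from dropping one element (`[1...]`) to slicing the last `repetitionContextSize` elements.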

davidkoski commented 6 months ago

#76 should fix this.