shawiz closed this 6 months ago.
I tried this on mlx-community/Mistral-7B-v0.1-hf-4bit-mlx without error. The Qwen1.5 models are not loading for me because of #53; I will try again once that is resolved.
OK, I can reproduce it. I won't check the fix in yet because of the quantization issues in #53; it will be included in that change.
If you want to try it locally, here is the change:
```diff
diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 6c85558..89212a7 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
) -> MLXArray {
if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
- var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+ var selectedLogits = logits[0..., indices]
selectedLogits = MLX.where(
selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
- self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+ self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
}
} else {
self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if parameters.repetitionContextSize > 1 {
- repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > parameters.repetitionContextSize {
- repetitionContext = repetitionContext[1...]
+ repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
}
}
```
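The penalty step in the first hunk gathers the logits of tokens already in the repetition context and makes each one less likely: negative logits are multiplied by the penalty and positive ones are divided by it, which is what the `MLX.where(selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)` call expresses. The sketch below models that on a single row of logits stored as `[Float]`, in plain Swift without MLX; `applyRepetitionPenaltySketch` is a hypothetical name, not part of the library.

```swift
/// Hypothetical plain-Swift model of the repetition penalty for one row of logits.
/// `context` holds token ids already seen; a `penalty` above 1 discourages repeats.
func applyRepetitionPenaltySketch(logits: [Float], context: [Int], penalty: Float) -> [Float] {
    var result = logits
    for token in context where logits.indices.contains(token) {
        // Read from the original logits so duplicate ids in `context`
        // are penalized once, like a gather followed by a scatter.
        let value = logits[token]
        result[token] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

let penalized = applyRepetitionPenaltySketch(
    logits: [2.0, -1.0, 0.5], context: [0, 1, 1], penalty: 2.0)
print(penalized)  // [1.0, -2.0, 0.5]
```

With a penalty of exactly 1 the logits are unchanged, and a penalty below 1 (such as the 0.5 mentioned in the report) makes previously seen tokens more likely rather than less.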
I switched it to use full array indexing and made it match the Python code. I don't know whether this is a bug in MLX core or in the calling code; the calling code certainly needs some changes, and I don't think it was logically equivalent to the Python version.
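The repetition context is a sliding window over the most recent token ids. Below is a plain-Swift sketch of that window, assuming the Python reference keeps the last `repetition_context_size` ids with a `[-n:]` slice and appends each sampled token before trimming; `RepetitionWindowSketch` is a hypothetical type and uses `[Int]` instead of `MLXArray`.

```swift
/// Hypothetical sliding window over the most recent token ids.
struct RepetitionWindowSketch {
    let size: Int
    private(set) var tokens: [Int]

    init(prompt: [Int], size: Int) {
        self.size = size
        // Like Python's prompt[-size:]: keep at most the last `size` ids.
        self.tokens = Array(prompt.suffix(size))
    }

    mutating func append(_ token: Int) {
        tokens.append(token)
        // Drop the oldest ids once the window overflows.
        if tokens.count > size {
            tokens.removeFirst(tokens.count - size)
        }
    }
}

var window = RepetitionWindowSketch(prompt: [10, 11, 12, 13, 14], size: 3)
print(window.tokens)  // [12, 13, 14]
window.append(15)
print(window.tokens)  // [13, 14, 15]
```

`Array.suffix(_:)` returns the whole array when it is shorter than the window, which corresponds to the `prompt.shape[0] <= parameters.repetitionContextSize` branch in the second hunk.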
I'm passing a repetitionPenalty to the GenerateParameters constructor. Regardless of the value I set (I tried 0.5, 1, 1.2, and 1.5), the program crashes immediately when the evaluator runs. I was testing various Qwen1.5 models. The error message I got is:
```
-[MTLDebugComputeCommandEncoder dispatchThreads:threadsPerThreadgroup:]:1441: failed assertion '(threadsPerGrid.width(0) * threadsPerGrid.y(1) * threadsPerGrid.depth(0))(0) must not be 0.'
```
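The thread does not pin down the root cause, but the assertion says a Metal kernel was dispatched over a grid with zero threads, which typically means an operation ran on an empty array. One hypothesis, not confirmed here: if a closed range such as `-n ... -1` (the removed line in the second hunk) is converted to an exclusive upper bound by adding 1, the stop becomes `0`, and a slice from `-n` to `0` is empty, while the replacement `(-n)...` has no upper bound and keeps the last `n` elements. The sketch models only Python-style slice semantics in plain Swift; `pythonSlice` is a hypothetical helper, not an MLX API, and it does not show how MLX Swift actually converts ranges.

```swift
/// Hypothetical model of Python-style slicing `xs[start:stop]`: negative
/// bounds count from the end and out-of-range bounds are clamped.
func pythonSlice<T>(_ xs: [T], start: Int, stop: Int) -> [T] {
    let n = xs.count
    func normalize(_ i: Int) -> Int {
        let j = i < 0 ? i + n : i
        return min(max(j, 0), n)
    }
    let lo = normalize(start), hi = normalize(stop)
    return lo < hi ? Array(xs[lo..<hi]) : []
}

let prompt = [10, 11, 12, 13, 14]
let contextSize = 3

// Open-ended slice, like Python's prompt[-3:]: the last three ids.
print(pythonSlice(prompt, start: -contextSize, stop: prompt.count))  // [12, 13, 14]

// A closed range -3 ... -1 turned into an exclusive stop of -1 + 1 == 0,
// i.e. prompt[-3:0]: empty, so any kernel run over it has zero threads.
print(pythonSlice(prompt, start: -contextSize, stop: -1 + 1))  // []
```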