ml-explore / mlx-examples

Examples in the MLX framework

LLMEvaluator : libc++abi: terminating due to uncaught exception of type std::invalid_argument: [matmul] Last dimension of first input with shape (1,916,2048) must match second to last dimension of second input with shape (256,32000) #820

Closed. Paramstr closed this issue 2 weeks ago

Paramstr commented 3 weeks ago

Using the model through LLMEvaluator gives the following error:

libc++abi: terminating due to uncaught exception of type std::invalid_argument: [matmul] Last dimension of first input with shape (1,916,2048) must match second to last dimension of second input with shape (256,32000)
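
For context, the matmul rule the message refers to is easy to reproduce in isolation with the shapes from the error. This is just an illustration of the shape check, not the app code, and it assumes the zeros/matmul free functions from the mlx-swift package:

import MLX

// Shapes copied from the error message.  matmul requires the last dimension of
// the first operand (2048) to match the second-to-last dimension of the second
// operand (256).  Because the check fails in the C++ layer, the exception is not
// catchable from Swift and the process terminates with the libc++abi message.
let hidden = MLX.zeros([1, 916, 2048])
let weight = MLX.zeros([256, 32000])
_ = MLX.matmul(hidden, weight)   // aborts here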

How I used it (a sketch of a possible call site follows the listing):


///--------------------------------------------------------------------------------------------------------
///
/// Manages the shared LLMEvaluator object
/// - Ensures only one LLM is loaded into memory
///
class LLMEvaluatorManager: ObservableObject {
    static let sharedLLM = LLMEvaluatorManager()

    @Published var llmEvaluator: LLMEvaluator?

    private init() {}

    func loadLLMEvaluator() {
        if llmEvaluator == nil {
            llmEvaluator = LLMEvaluator()
        }
    }
}
///--------------------------------------------------------------------------------------------------------

@Observable
class LLMEvaluator {

    @MainActor
    var running = false
    var runs: Int = 0
    var output = ""
    var modelInfo = ""
    var stat = ""
    var modelInputTokens = 0
    var modelOutputTokens = 0
    /// this controls which model loads -- here it is a custom Gemma code-instruct
    /// fine-tune converted to MLX
    ///
    //let modelConfiguration = ModelConfiguration.gemma2bQuantized

    let modelConfiguration = ModelConfiguration(
        id: "Paramstr/MLX-gemma-Code-Instruct-Finetune-test"
        //,overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }

//    let modelConfiguration = ModelConfiguration(
//        id: "mlx-community/OpenELM-1_1B-8bit"
//    )

    /// parameters controlling the output
    let temperature: Float = 0.6
    let maxTokens = 2000

    /// update the display every N tokens -- a small value looks like it updates
    /// continuously and is low overhead.  observed ~15% reduction in tokens/s when
    /// updating on every token
    let displayEveryNTokens = 10

    enum LoadState {
        case idle
        case loaded(LLMModel, Tokenizers.Tokenizer)
    }

    var loadState = LoadState.idle
    var loadloop = 0

    /// load and return the model -- can be called multiple times, subsequent calls will
    /// just return the loaded model
    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
        loadloop += 1
        print("load() function loop: \(loadloop)")
        print("⚫️ LLM: Load() function accessed")
        switch loadState {
        case .idle:
            print("⚫️ LLM: Loading model and tokenizer...")
            // limit the buffer cache

            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

            let (model, tokenizer) = try await MLXLLM.load(configuration: modelConfiguration) {
                [modelConfiguration] progress in
                DispatchQueue.main.sync {
                    self.modelInfo =
                        "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
                }
            }
            self.modelInfo =
                "Loaded \(modelConfiguration.id).  Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
            loadState = .loaded(model, tokenizer)
            print("⚫️ LLM: Model and tokenizer loaded successfully.")

            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            print("⚫️ LLM: Model and tokenizer already Loaded.")
            return (model, tokenizer)
        }
    }

    func generate(prompt: String) async -> String {
            let canGenerate = await MainActor.run {
                if running {
                    return false
                } else {
                    running = true
                    self.output = ""
                    return true
                }
            }

            guard canGenerate else { return self.output }

            do {
                let (model, tokenizer) = try await load()
                // augment the prompt as needed
                let prompt = modelConfiguration.prepare(prompt: prompt)
                let promptTokens = tokenizer.encode(text: prompt)

                // each time you generate you will get something new
                MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

                // pass the sampling temperature declared above
                let result = await MLXLLM.generate(
                    promptTokens: promptTokens, parameters: GenerateParameters(temperature: temperature),
                    model: model, tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
                ) { tokens in
                    // update the output -- this will make the view show the text as it generates
                    if tokens.count % displayEveryNTokens == 0 {
                        let text = tokenizer.decode(tokens: tokens)
                        await MainActor.run {
                            self.output = text
                        }
                    }

                    if tokens.count >= maxTokens {
                        return .stop
                    } else {
                        return .more
                    }
                }

                // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
                await MainActor.run {
                    if result.output != self.output {
                        self.output = result.output
                    }

                    running = false
                    self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
                }

            } catch {
                await MainActor.run {
                    running = false
                    output = "Failed: \(error)"
                }
            }
        return self.output

    }

}
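
For reference, this is roughly how the evaluator gets called from the UI (a minimal sketch with illustrative names, assuming the two classes above):

import SwiftUI

struct PromptView: View {
    // the shared manager guarantees a single LLMEvaluator instance
    @ObservedObject var manager = LLMEvaluatorManager.sharedLLM
    @State private var reply = ""

    var body: some View {
        Text(reply)
            .task {
                manager.loadLLMEvaluator()
                guard let llm = manager.llmEvaluator else { return }
                // generate() is where the matmul exception is raised
                reply = await llm.generate(prompt: "Write a haiku about Swift.")
            }
    }
}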

Note: the same code works with an MLX-converted Gemma 1.1b model.
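
Since the Gemma 1.1b conversion works and this fine-tune does not, my guess (and it is only a guess) is that the failing repo stores packed quantized weights while its config.json does not declare it: 256 is 2048 / 8, which is what a 4-bit packed projection would look like if it were handed to a plain, unquantized matmul. MLX conversions normally record this in a "quantization" entry of config.json, so one quick thing to compare between the two repos is the config itself (the path below is a placeholder for the downloaded snapshot):

import Foundation

// Hypothetical sanity check on the downloaded snapshot: print the config fields
// the loader uses to build the layers.  A quantized MLX conversion should contain
// a "quantization" entry such as {"group_size": 64, "bits": 4}.
let url = URL(fileURLWithPath: "/path/to/model-snapshot/config.json")
if let data = try? Data(contentsOf: url),
   let config = try? JSONSerialization.jsonObject(with: data) as? [String: Any] {
    for key in ["model_type", "hidden_size", "vocab_size", "quantization"] {
        print(key, "=", config[key] ?? "missing")
    }
}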