ml-explore / mlx-swift-examples

Examples using MLX Swift
MIT License

Download progress is not accurate #77

Closed. DePasqualeOrg closed this issue 6 months ago.

DePasqualeOrg commented 6 months ago

In Libraries/LLM/Load.swift, the tokenizer is loaded first, and as part of that step config.json is downloaded if it hasn't been downloaded already. That same file is then listed again in modelFiles, which are downloaded next, and this tiny file counts toward the download progress as much as a multi-gigabyte .safetensors file. As a result, when the model consists of a single weights file, the progress noticeably starts or ends at 50%. Wouldn't it make sense to leave config.json out of modelFiles?
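To make the weighting concrete, here is a small Foundation-only sketch. It assumes, as the issue describes, that the snapshot progress counts one unit per matched file regardless of its size. The file names and byte counts are illustrative, and perFileProgress(files:) is a hypothetical helper, not part of the Hub API.

```swift
import Foundation

// Hypothetical single-weights-file model; sizes are illustrative only.
let fileSizes: [String: Int64] = [
    "config.json": 1_500,
    "model.safetensors": 4_000_000_000,
]

// Assumption: one progress unit per file, regardless of its size,
// which is the weighting described in the issue.
func perFileProgress(files: [String]) -> Progress {
    Progress(totalUnitCount: Int64(files.count))
}

// Current pattern: config.json is listed alongside the weights.
let withConfig = perFileProgress(files: ["config.json", "model.safetensors"])
withConfig.completedUnitCount += 1  // config.json finished (or was already cached)
print(withConfig.fractionCompleted)  // 0.5, although almost no bytes have moved

// Proposed pattern: only the weights are tracked.
let weightsOnly = perFileProgress(files: ["model.safetensors"])
print(weightsOnly.fractionCompleted)  // 0.0 until the large file makes progress

// For comparison: the share of the total bytes that config.json actually represents.
let totalBytes = fileSizes.values.reduce(0, +)
let configShare = Double(fileSizes["config.json"]!) / Double(totalBytes)
print(configShare < 0.001)  // true
```

With config.json in the list, half of the bar is spent on a file that is a vanishingly small fraction of the bytes, which matches the jump to 50% reported above.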

/// Load and return the model and tokenizer
public func load(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
    do {
        let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

        let modelDirectory: URL

        switch configuration.id {
        case .id(let id):
            // download the model weights and config
            let repo = Hub.Repo(id: id)
            let modelFiles = ["config.json", "*.safetensors"]
            modelDirectory = try await hub.snapshot(
                from: repo, matching: modelFiles, progressHandler: progressHandler)

        case .directory(let directory):
            modelDirectory = directory
        }

        // create the model (no weights loaded)
        let configurationURL = modelDirectory.appending(component: "config.json")
        let baseConfig = try JSONDecoder().decode(
            BaseConfiguration.self, from: Data(contentsOf: configurationURL))

        let model = try baseConfig.modelType.createModel(configuration: configurationURL)

        // load the weights
        var weights = [String: MLXArray]()
        let enumerator = FileManager.default.enumerator(
            at: modelDirectory, includingPropertiesForKeys: nil)!
        for case let url as URL in enumerator {
            if url.pathExtension == "safetensors" {
                let w = try loadArrays(url: url)
                for (key, value) in w {
                    weights[key] = value
                }
            }
        }

        // quantize if needed
        if let quantization = baseConfig.quantization {
            quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
        }

        // apply the loaded weights
        let parameters = ModuleParameters.unflattened(weights)
        try model.update(parameters: parameters, verify: [.all])

        eval(model)

        return (model, tokenizer)

    } catch Hub.HubClientError.authorizationRequired {
        // an authorizationRequired means (typically) that the named repo doesn't exist
        // on the server, so retry with a local-only configuration
        var newConfiguration = configuration
        newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
        return try await load(
            hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
    }
}
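One cautious way to apply the suggestion is to request config.json in the weights step only when it is not already present in the model directory, so the file is never missing if the tokenizer step did not fetch it. This is a hypothetical variant, not the merged change: weightDownloadGlobs(directory:) is an illustrative helper, and it assumes the tokenizer download places config.json in the same snapshot directory that the weights step uses.

```swift
import Foundation

/// Hypothetical helper: choose the glob patterns for the weights download.
/// config.json is requested only if it is not already in `directory`
/// (for example, because the tokenizer step already fetched it).
func weightDownloadGlobs(directory: URL) -> [String] {
    let configURL = directory.appendingPathComponent("config.json")
    var globs = ["*.safetensors"]
    if !FileManager.default.fileExists(atPath: configURL.path) {
        // Fall back to the original behavior when config.json is missing.
        globs.insert("config.json", at: 0)
    }
    return globs
}
```

In load(...), this would replace the fixed list, e.g. let modelFiles = weightDownloadGlobs(directory: configuration.modelDirectory(hub: hub)), so that in the common case only the .safetensors files drive the progress.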
davidkoski commented 6 months ago

That seems like a reasonable idea -- do you want to make a PR for it?

DePasqualeOrg commented 6 months ago

https://github.com/ml-explore/mlx-swift-examples/pull/78