huggingface / swift-transformers

Swift Package to implement a transformers-like API in Swift
Apache License 2.0

download progress handler not being called #59

Closed: davidkoski closed this issue 6 months ago

davidkoski commented 6 months ago

tag 0.1.2

I see the progress handler for hub.snapshot() being called only at the beginning and end:

    let hub = HubApi()
    let repo = Hub.Repo(id: "mlx-community/starcoder2-3b-4bit")
    let modelFiles = ["config.json", "*.safetensors"]
    let modelDirectory = try await hub.snapshot(
        from: repo, matching: modelFiles) { progress in
            print(progress)
        }

gives:

    <NSProgress: 0x600001d7e200> : Parent: 0x0 (portion: 0) / Fraction completed: 0.0000 / Completed: 0 of 2
      <NSProgress: 0x600001d7ec80> : Parent: 0x600001d7e200 (portion: 1) / Fraction completed: 0.0000 / Completed: 0 of 100

    <NSProgress: 0x600001d7e200> : Parent: 0x0 (portion: 0) / Fraction completed: 1.0000 / Completed: 2 of 2
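The printed tree suggests how snapshot() reports progress: a parent with one unit per matched file (Completed: 0 of 2) and a 100-unit child attached with a portion of 1. Assuming that structure (it is inferred from the output, not taken from the swift-transformers source), here is a minimal Foundation sketch of the intermediate values the handler would be expected to see as the child advances. parentFractions(forChildSteps:) is a hypothetical helper.

```swift
import Foundation

// Hypothetical helper mirroring the printed tree: a parent with 2 units and a
// 100-unit child that is worth 1 unit of the parent (assumed structure).
func parentFractions(forChildSteps steps: [Int64]) -> [Double] {
    let parent = Progress(totalUnitCount: 2)
    let child = Progress(totalUnitCount: 100, parent: parent, pendingUnitCount: 1)
    var fractions: [Double] = []
    for completed in steps {
        // Each child update rolls up into the parent's fractionCompleted.
        child.completedUnitCount = completed
        fractions.append(parent.fractionCompleted)
    }
    return fractions
}

// Halfway through the first file: (0.5 * 1 unit) / 2 units = 0.25.
// First file finished: 1 unit / 2 units = 0.5.
print(parentFractions(forChildSteps: [50, 100]))
```

In the report above, none of these intermediate values reach the handler; it only sees 0 of 2 and 2 of 2.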

I swear I have seen this get called during the download, but I can't seem to reproduce it now.

I can see:

    extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
        func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) {

get called, and that this subscriber:

    class Downloader: NSObject, ObservableObject {
    ...
        @discardableResult
        func waitUntilDone() throws -> URL {
            // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
            let semaphore = DispatchSemaphore(value: 0)
            stateSubscriber = downloadState.sink { state in
                switch state {
                case .completed: semaphore.signal()
                case .failed:    semaphore.signal()
                default:         break
                }
            }

is receiving them. I think the problem may be here:

    public extension HubApi {
        func localRepoLocation(_ repo: Repo) -> URL {
            downloadBase.appending(component: repo.type.rawValue).appending(component: repo.id)
        }

        struct HubFileDownloader {
            ...
            @discardableResult
            func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
                guard !downloaded else { return destination }

                try prepareDestination()
                let downloader = Downloader(from: source, to: destination, using: hfToken)
                let downloadSubscriber = downloader.downloadState.sink { state in
                    if case .downloading(let progress) = state {
                        progressHandler(progress)
                    }
                }
                // We need to assign the cancellable to a var so we keep receiving events, so we suppress the "unused var" warning here
                let _ = downloadSubscriber
                try downloader.waitUntilDone()
                return destination
            }

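That sink never gets called. I think the cancellable returned by sink is not being kept alive here in Release builds. It may work in Debug builds, but the variable is dead before the call to waitUntilDone() (the `let _ = downloadSubscriber` line is itself a use that happens before the blocking call), so the optimized code is free to release the cancellable early, which cancels the subscription and drops the sink closure.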
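The lifetime rule can be illustrated without Combine. In this minimal sketch, SubscriptionToken, ProgressSource and the download functions are hypothetical stand-ins, not swift-transformers or Combine API. Like AnyCancellable, the token removes its subscription when it is deallocated. Swift only promises that a local lives until its last use, so both functions make sure the last use of the token comes after the blocking work: one wraps that work in the standard library's withExtendedLifetime(_:_:), the other calls cancel() in a defer.

```swift
// Hypothetical stand-in for Combine's AnyCancellable: the subscription
// is removed when the token is cancelled or deallocated.
final class SubscriptionToken {
    private var onCancel: (() -> Void)?

    init(onCancel: @escaping () -> Void) {
        self.onCancel = onCancel
    }

    // Removes the subscription; safe to call more than once.
    func cancel() {
        onCancel?()
        onCancel = nil
    }

    deinit {
        cancel()
    }
}

// Hypothetical publisher that only delivers values to live subscriptions.
final class ProgressSource {
    private var handlers: [Int: (Double) -> Void] = [:]
    private var nextID = 0

    var subscriberCount: Int { handlers.count }

    func subscribe(_ handler: @escaping (Double) -> Void) -> SubscriptionToken {
        let id = nextID
        nextID += 1
        handlers[id] = handler
        return SubscriptionToken { [weak self] in
            guard let source = self else { return }
            source.handlers[id] = nil
        }
    }

    func send(_ value: Double) {
        for handler in handlers.values {
            handler(value)
        }
    }
}

// Plays the role of the blocking waitUntilDone() call: emits four progress values.
func simulateBlockingDownload(_ source: ProgressSource) {
    for step in 1...4 {
        source.send(Double(step) / 4)
    }
}

// Option 1: withExtendedLifetime keeps the token alive for the whole closure body.
func downloadExtendingLifetime(_ source: ProgressSource, progress: @escaping (Double) -> Void) {
    let token = source.subscribe(progress)
    withExtendedLifetime(token) {
        simulateBlockingDownload(source)
    }
}

// Option 2: the deferred cancel() is a use of the token after the blocking
// call, so the token cannot be released before the download finishes.
func downloadDeferringCancel(_ source: ProgressSource, progress: @escaping (Double) -> Void) {
    let token = source.subscribe(progress)
    defer { token.cancel() }
    simulateBlockingDownload(source)
}
```

Both variants deliver every intermediate value and leave no subscription behind once the function returns. The defer form has the small advantage of cancelling explicitly instead of relying on deallocation.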

I think something like this would work (I tested it, and it seems to :-) )

        func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
            guard !downloaded else { return destination }

            try prepareDestination()
            let downloader = Downloader(from: source, to: destination, using: hfToken)
            let downloadSubscriber = downloader.downloadState.sink { state in
                if case .downloading(let progress) = state {
                    progressHandler(progress)
                }
            }
            _ = try withExtendedLifetime(downloadSubscriber) {
                try downloader.waitUntilDone()
            }
            return destination
        }

What do you think?

ZachNagengast commented 6 months ago

Nice! Also noticed this intermittently 👍