mustafaaljadery / lightning-whisper-mlx

An extremely fast implementation of whisper optimized for Apple Silicon using MLX.
https://mustafaaljadery.github.io/lightning-whisper-mlx/
588 stars, 30 forks

Align load_models.py to latest MLX version for quantization #14

Closed: ivanfioravanti closed this 6 months ago

ivanfioravanti commented 6 months ago

This fixes issue #11.

kadirnar commented 6 months ago

I updated the code but it doesn't work. Can you fix this?

ValueError: [addmm] Last dimension of first input with shape (1,1500,1280) must match second to last dimension of second input with shape (160,1280).

x4080 commented 5 months ago

@kadirnar I got the same error. Did you find a solution?

NoahBPeterson commented 1 month ago

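@x4080 @kadirnar @mustafaaljadery Yes: the weights = tree_unflatten(list(weights.items())) statement needs to be moved after the if quantization is not None: branch. The class_predicate checks for flat keys such as "<path>.scales", and those keys only exist before the weights are unflattened into a nested dict. With the old order, no layer is quantized, so the packed quantized weights are loaded into ordinary nn.Linear layers, which produces the shape mismatch above.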

Like so:

load_models.py

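import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten

from . import whisper

def load_model(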
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(model_path / "config.json"), "r") as f:
        config = json.loads(f.read())
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    weights = mx.load(str(model_path / "weights.npz"))
-   weights = tree_unflatten(list(weights.items()))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

+   # Unflatten only now: class_predicate above needs the flat "<path>.scales" keys.
+   weights = tree_unflatten(list(weights.items()))
    model.update(weights)
    mx.eval(model.parameters())
    return model
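A minimal pure-Python sketch of why the order matters, assuming checkpoint keys shaped like the illustrative ones below (the layer names are assumptions, not read from a real checkpoint). unflatten is a hypothetical, simplified stand-in for mlx.utils.tree_unflatten so the snippet runs without MLX; has_scales mirrors the membership test in class_predicate. The last step shows where the 160 in the error comes from, assuming 4-bit quantization packed into uint32 words.

```python
# Hypothetical stand-in for mlx.utils.tree_unflatten, so this runs without MLX.
# It turns dotted keys like "encoder.blocks.0.mlp1.weight" into nested dicts
# (numeric parts such as "0" stay dict keys here; MLX builds lists for them).
def unflatten(flat):
    tree = {}
    for key, value in flat.items():
        node = tree
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Illustrative keys from a quantized checkpoint (names are assumptions).
flat_weights = {
    "encoder.blocks.0.mlp1.weight": "packed uint32",
    "encoder.blocks.0.mlp1.scales": "scales",
    "encoder.blocks.0.mlp1.biases": "biases",
}


def has_scales(weights, path):
    # Same membership test as class_predicate in load_model.
    return f"{path}.scales" in weights


path = "encoder.blocks.0.mlp1"
before = has_scales(flat_weights, path)            # flat dict: dotted key exists
after = has_scales(unflatten(flat_weights), path)  # nested dict: only "encoder" at top level

# 4-bit quantization packs 32 // 4 = 8 values per uint32, so a layer with
# 1280 input features stores a weight of width 1280 * 4 // 32 = 160.
packed_width = 1280 * 4 // 32

print(before, after, packed_width)  # True False 160
```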