ivanfioravanti closed this 6 months ago
I updated the code but it doesn't work. Can you fix this?
ValueError: [addmm] Last dimension of first input with shape (1,1500,1280) must match second to last dimension of second input with shape (160,1280).
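The two shapes in this error fit a quantized weight being loaded into a layer that was never quantized. Here is a minimal arithmetic sketch. It assumes 4-bit quantization that packs eight values into each 32-bit word along the input dimension. `packed_weight_shape` is a hypothetical helper, not an MLX function.

```python
def packed_weight_shape(out_features: int, in_features: int, bits: int = 4) -> tuple[int, int]:
    """Shape of a weight matrix after packing `bits`-bit values into 32-bit words.

    Assumption (hypothetical helper): values are packed along the input
    dimension, so each row of `in_features` values takes in_features * bits / 32 words.
    """
    if (in_features * bits) % 32 != 0:
        raise ValueError("in_features * bits must be a multiple of 32")
    return out_features, in_features * bits // 32


# A 1280x1280 projection (hidden size 1280, as in the error) at 4 bits per value.
w_shape = packed_weight_shape(1280, 1280, bits=4)
print(w_shape)  # (1280, 160)

# A plain linear layer computes x @ W.T, so the matmul sees W.T with shape (160, 1280).
# The input's last dimension (1280) no longer matches W.T's second-to-last (160).
x_last_dim = 1280
wt_shape = (w_shape[1], w_shape[0])
print(wt_shape)                   # (160, 1280)
print(x_last_dim == wt_shape[0])  # False
```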
@kadirnar I got the same result. Did you find a solution?
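@x4080 @kadirnar @mustafaaljadery Yes. The `weights = tree_unflatten(list(weights.items()))` statement needs to be moved after the `if quantization is not None:` branch. The `class_predicate` looks for flat keys of the form `f"{p}.scales"`, so it only finds them while `weights` is still the flat dict returned by `mx.load`. If `weights` is unflattened first, no layer gets quantized. `model.update` then loads the packed quantized weights into plain layers, which causes the shape mismatch above.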
Like so, in `load_models.py`:
```diff
 def load_model(
     path_or_hf_repo: str,
     dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
     model_path = Path(path_or_hf_repo)
     if not model_path.exists():
         model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
     with open(str(model_path / "config.json"), "r") as f:
         config = json.loads(f.read())
         config.pop("model_type", None)
         quantization = config.pop("quantization", None)
     model_args = whisper.ModelDimensions(**config)
     weights = mx.load(str(model_path / "weights.npz"))
-    weights = tree_unflatten(list(weights.items()))
     model = whisper.Whisper(model_args, dtype)
     if quantization is not None:
         class_predicate = (
             lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
             and f"{p}.scales" in weights
         )
         nn.quantize(model, **quantization, class_predicate=class_predicate)
+    # Unflatten only after quantizing: class_predicate needs the flat "<path>.scales" keys.
+    weights = tree_unflatten(list(weights.items()))
     model.update(weights)
     mx.eval(model.parameters())
     return model
```
This fixes issue #11.
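The order matters because `f"{p}.scales" in weights` is a plain dict membership test. The pure-Python sketch below runs that test against the flat dict and against its unflattened form. `unflatten` is a simplified, hypothetical stand-in for `mlx.utils.tree_unflatten`, and the key names are illustrative only.

```python
from typing import Any


def unflatten(flat: dict[str, Any]) -> dict[str, Any]:
    """Simplified, hypothetical stand-in for mlx.utils.tree_unflatten.

    Turns dotted keys into nested dicts. The real function also turns integer
    path components into lists; that detail does not affect the lookup here.
    """
    nested: dict[str, Any] = {}
    for key, value in flat.items():
        node = nested
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Flat dotted keys, like those in a loaded .npz file (names are illustrative).
flat_weights = {
    "encoder.blocks.0.attn.query.weight": "packed",
    "encoder.blocks.0.attn.query.scales": "scales",
    "encoder.blocks.0.attn.query.biases": "biases",
}
path = "encoder.blocks.0.attn.query"

# Predicate checked before unflattening: the key is found, so the layer would be quantized.
print(f"{path}.scales" in flat_weights)  # True

# Predicate checked after unflattening: only top-level keys remain, so nothing matches.
nested_weights = unflatten(flat_weights)
print(list(nested_weights))                # ['encoder']
print(f"{path}.scales" in nested_weights)  # False
```

Calling `tree_unflatten` only after `nn.quantize` keeps the predicate working against the flat keys. The nested form is needed only for `model.update`.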