This PR allows MLX and CoreML to coexist. This is done by:
- adding MLX model loading to WhisperKit
- adding shape adapters so that the output of feature extraction and the audio encoder is compatible with CoreML
- adding MLX tests to CI
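A shape adapter of this kind is essentially a permutation of axes plus an inserted singleton axis. The sketch below shows the idea in plain Swift, with `[Float]` arrays standing in for `MLXArray` and `MLMultiArray`. The concrete layouts are assumptions for illustration only (an MLX-style `[batch, frames, channels]` tensor and a CoreML-style `[batch, channels, 1, frames]` tensor). The PR does not state the exact shapes, and `Tensor`, `toCoreMLLayout` and `fromCoreMLLayout` are hypothetical names, not WhisperKit APIs.

```swift
// Hypothetical layouts, assumed for illustration only:
//   MLX side:    [batch, frames, channels]     (row-major, contiguous)
//   CoreML side: [batch, channels, 1, frames]  (row-major, contiguous)
struct Tensor {
    var shape: [Int]
    var data: [Float]  // row-major, contiguous storage
}

/// Hypothetical adapter: [B, T, C] -> [B, C, 1, T]
func toCoreMLLayout(_ x: Tensor) -> Tensor {
    precondition(x.shape.count == 3, "expected [B, T, C]")
    let (b, t, c) = (x.shape[0], x.shape[1], x.shape[2])
    var out = [Float](repeating: 0, count: x.data.count)
    for bi in 0..<b {
        for ti in 0..<t {
            for ci in 0..<c {
                let src = (bi * t + ti) * c + ci
                // The singleton axis has size 1, so it adds nothing to the offset.
                let dst = (bi * c + ci) * t + ti
                out[dst] = x.data[src]
            }
        }
    }
    return Tensor(shape: [b, c, 1, t], data: out)
}

/// Hypothetical inverse adapter: [B, C, 1, T] -> [B, T, C]
func fromCoreMLLayout(_ x: Tensor) -> Tensor {
    precondition(x.shape.count == 4 && x.shape[2] == 1, "expected [B, C, 1, T]")
    let (b, c, t) = (x.shape[0], x.shape[1], x.shape[3])
    var out = [Float](repeating: 0, count: x.data.count)
    for bi in 0..<b {
        for ci in 0..<c {
            for ti in 0..<t {
                let src = (bi * c + ci) * t + ti
                let dst = (bi * t + ti) * c + ci
                out[dst] = x.data[src]
            }
        }
    }
    return Tensor(shape: [b, t, c], data: out)
}

let original = Tensor(shape: [1, 3, 2], data: [0, 1, 2, 3, 4, 5])
let coreML = toCoreMLLayout(original)
let back = fromCoreMLLayout(coreML)
print(coreML.shape, coreML.data)  // [1, 2, 1, 3] [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]
print(back.shape == original.shape && back.data == original.data)  // true
```

Both adapters assume contiguous row-major storage on both sides. That assumption is what a round-trip check like `testArrayConversion` implicitly relies on whenever a buffer is read as a flat block of elements.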
Problem
There is one problem I've been debugging for a while now. The current architecture requires converting back and forth between `MLMultiArray` and `MLXArray`. Additionally, the `MLXArray` has to be reshaped so that it fits the CoreML model input. There are some helper methods for this:
- `asMLMultiArray`
- `asMLXArray`
- `asMLXOutput`
- `asMLXInput`
To test the correctness of these methods, I've added the `testArrayConversion` test. It fails in one case: when `asMLXOutput().asMLMultiArray().asMLXArray(Int32.self).asMLXInput()` are chained in that particular order. This is hard to explain, because it should work: first we expand the array and change its shape, then we convert it to `MLMultiArray`, then we convert it back to `MLXArray`, and finally we change the shape back to the original one. The result should be identical to the original array, but it is not.
This shows up in WhisperKit when we use `MLXFeatureExtractor` and `MLXAudioEncoder` together: the output is usually an empty transcription. When I use only `MLXFeatureExtractor`, the transcription is correct.
I suspect something is wrong in the conversion from `MLXArray` to `MLMultiArray`, but I haven't found it yet.
Edit: Solution
The issue was in `asMLXArray`, as pointed out by @ZachNagengast: when converting to `MLXArray`, we need to use the strides of the `MLMultiArray`.
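In other words, the conversion must not treat the `MLMultiArray` memory as a dense block of `count` elements. `MLMultiArray` exposes `shape` and `strides` (both `[NSNumber]`), and a stride can be larger than the product of the trailing dimensions, for example when rows are padded. The sketch below is a plain-Swift illustration of that failure mode, not the actual `asMLXArray` fix. `StridedBuffer`, `naiveContiguousCopy` and `stridedContiguousCopy` are hypothetical names, and the padded layout in the example (rows of 3 elements stored with a stride of 4) is made up for demonstration. The PR does not say which dimension was padded.

```swift
// Hypothetical stand-in for an MLMultiArray: flat storage plus per-dimension
// strides measured in elements. Padding means storage can be longer than
// shape.reduce(1, *).
struct StridedBuffer {
    var shape: [Int]
    var strides: [Int]
    var storage: [Int32]
}

/// Naive copy: assumes the buffer is contiguous and takes the first `count` elements.
func naiveContiguousCopy(_ src: StridedBuffer) -> [Int32] {
    let count = src.shape.reduce(1, *)
    return Array(src.storage.prefix(count))
}

/// Stride-aware copy: visits every logical index in row-major order and
/// computes its storage offset from the strides, skipping any padding.
func stridedContiguousCopy(_ src: StridedBuffer) -> [Int32] {
    let count = src.shape.reduce(1, *)
    var out = [Int32]()
    out.reserveCapacity(count)
    var index = [Int](repeating: 0, count: src.shape.count)
    for _ in 0..<count {
        var offset = 0
        for axis in 0..<index.count {
            offset += index[axis] * src.strides[axis]
        }
        out.append(src.storage[offset])
        // Advance the multi-dimensional index, last dimension fastest.
        var dim = index.count - 1
        while dim >= 0 {
            index[dim] += 1
            if index[dim] < src.shape[dim] { break }
            index[dim] = 0
            dim -= 1
        }
    }
    return out
}

// Shape [2, 3] with each row padded to 4 elements (-1 marks padding).
let padded = StridedBuffer(shape: [2, 3], strides: [4, 1],
                           storage: [1, 2, 3, -1, 4, 5, 6, -1])
print(naiveContiguousCopy(padded))    // [1, 2, 3, -1, 4, 5]
print(stridedContiguousCopy(padded))  // [1, 2, 3, 4, 5, 6]
```

For an already contiguous buffer both copies agree. This is why such a bug can hide until a particular chain of conversions produces a padded layout.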