apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
BSD 3-Clause "New" or "Revised" License
4.32k stars, 627 forks

Model state doesn't work with transpose #2278

Open seayoung1112 opened 1 month ago

seayoung1112 commented 1 month ago

## 🐞 Describing the bug

This is a follow-up to the original issue; sorry, I couldn't find the option to reopen it. To clarify, this is not the batch-prediction case where a list of tensors is passed as input: the input here is only one tensor.

I tried to lower a PyTorch Llama model with a KV cache to Core ML using the stateful-model feature introduced in coremltools 8.0. The export steps succeeded and I could generate an `.mlpackage`; however, at runtime the code failed immediately when constructing the model class. The error message is: "Fatal error: 'try!' expression unexpectedly raised an error: Error Code=0 "MIL program input, 'k_cache', not found in Core ML model inputs" UserInfo={NSLocalizedDescription=MIL program input, 'k_cache', not found in Core ML model inputs}"

I debugged a bit and found that a `view` + `transpose` combination causes this issue, but I couldn't get any further insight into why. The code to reproduce it is attached below. Specifically, in that code, if I change this:

```python
k = k.view(1, seqlen, 16, 128)
v = v.view(1, seqlen, 16, 128)

k = k.transpose(1, 2)
v = v.transpose(1, 2)
```

by adding the following lines after the `transpose` calls:

```python
k = k.view(1, 16, seqlen, 128)
v = v.view(1, 16, seqlen, 128)
```

inference works.
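The `view`-then-`transpose` sequence above changes only the tensor's shape and stride metadata, not its memory layout. As a minimal stdlib-Python sketch of that bookkeeping (the helpers `contiguous_strides`, `transpose` and `is_contiguous` are hypothetical, written only for illustration, and simplify PyTorch's real contiguity rules), the transposed `k` is no longer contiguous, while the extra `view(1, 16, seqlen, 128)` requests exactly the shape it already has. This does not identify the converter bug; it only shows what the two variants look like to PyTorch before conversion.

```python
# Stdlib-only sketch of shape/stride bookkeeping; these helpers are hypothetical
# and simplify PyTorch's rules (e.g. PyTorch ignores strides of size-1 dims).


def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a tensor of `shape`."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose(shape, strides, d0, d1):
    """Swap two dims by permuting metadata only, like Tensor.transpose."""
    shape, strides = list(shape), list(strides)
    shape[d0], shape[d1] = shape[d1], shape[d0]
    strides[d0], strides[d1] = strides[d1], strides[d0]
    return tuple(shape), tuple(strides)


def is_contiguous(shape, strides):
    """True if `strides` equal the row-major strides for `shape`."""
    return tuple(strides) == contiguous_strides(shape)


seqlen = 48

# k = k.view(1, seqlen, 16, 128): contiguous output of the linear layer
shape = (1, seqlen, 16, 128)
strides = contiguous_strides(shape)

# k = k.transpose(1, 2): same memory, permuted shape and strides
t_shape, t_strides = transpose(shape, strides, 1, 2)

print(strides)                            # (98304, 2048, 128, 1)
print(t_shape, t_strides)                 # (1, 16, 48, 128) (98304, 128, 2048, 1)
print(is_contiguous(t_shape, t_strides))  # False
print(contiguous_strides(t_shape))        # (98304, 6144, 128, 1)
```

Because a `view` to the identical shape is accepted by PyTorch even on a non-contiguous tensor, the workaround leaves the values unchanged; it changes the traced graph, not the data.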

## Stack Trace

Only the error message above; there is no further stack trace.

## To Reproduce

```python
import torch
import torch.nn as nn
import coremltools as ct
from coremltools.optimize.torch.quantization import PostTrainingQuantizer, PostTrainingQuantizerConfig


class TestAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.wk = nn.Linear(2048, 2048, bias=False)
        self.wv = nn.Linear(2048, 2048, bias=False)
        self.wo = nn.Linear(128, 128, bias=False)

        # KV caches are registered as buffers so they can be converted to Core ML states
        cache_shape = (1, 16, 128, 128)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu"))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.float32, device="cpu"))

    def forward(self, embedding):
        bsz, seqlen, _ = embedding.shape

        k, v = self.wk(embedding), self.wv(embedding)

        k = k.view(1, seqlen, 16, 128)
        v = v.view(1, seqlen, 16, 128)

        # (1, seqlen, 16, 128) -> (1, 16, seqlen, 128)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Write into the first `seqlen` positions of each cache in place
        self.k_cache[:, :, 0:seqlen] = k
        self.v_cache[:, :, 0:seqlen] = v

        return self.wo(self.k_cache)


model_t = TestAttention().eval()

# int4 per-channel weight quantization for every nn.Linear
config = PostTrainingQuantizerConfig.from_dict(
    {"module_type_configs": {torch.nn.Linear: {"weight_dtype": "int4", "granularity": "per_channel"}}}
)
quantizer = PostTrainingQuantizer(model_t, config)
quantized_model = quantizer.compress()

inputs = (torch.rand(1, 48, 16 * 128),)
traced_model = torch.jit.trace(quantized_model, inputs)

# State names must match the registered buffer names
states = [
    ct.StateType(wrapped_type=ct.TensorType(shape=(1, 16, 128, 128)), name=v)
    for v in ["k_cache", "v_cache"]
]
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, 48, 16 * 128))],
    outputs=[ct.TensorType(name="op")],
    states=states,
    # Stated explicitly here; the converted program below targets CoreML8 (iOS 18)
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)
```

Note: if I change the code to the following, it runs successfully at inference time:

```python
k = k.view(1, seqlen, 16, 128)
v = v.view(1, seqlen, 16, 128)

k = k.transpose(1, 2)
v = v.transpose(1, 2)
k = k.view(1, 16, seqlen, 128)
v = v.view(1, 16, seqlen, 128)
```
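- Xcode code to run the model:

```swift
import CoreML

func predict_kv_cache() {
    // Crashes here: "MIL program input, 'k_cache', not found in Core ML model inputs"
    let model = try! test_attention().model

    guard let x = try? MLMultiArray(shape: [1, 48, 2048], dataType: .float16) else {
        fatalError("Unexpected runtime error. MLMultiArray")
    }
    // Fill all 1 * 48 * 2048 = 98304 elements
    for i in 0...98303 {
        x[i] = 0.1
    }

    let inputs = test_attentionInput(embedding: x)

    // Stateful prediction (iOS 18+): the state holds k_cache and v_cache
    let state = model.makeState()

    _ = try! model.prediction(from: inputs, using: state)
}
```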


It fails immediately when executing `let model = try! test_attention().model`.
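The fill loop in the Swift snippet hard-codes `0...98303`. A quick stdlib-Python check, assuming row-major (C-order) linear indexing for the `(1, 48, 2048)` input (an assumption made for illustration, not a statement about `MLMultiArray` internals), confirms the range spans every element from the first to the last. `unravel` is a hypothetical helper that mirrors NumPy's `unravel_index`.

```python
from math import prod


def unravel(index, shape):
    """Map a flat row-major (C-order) index to a multi-index (like numpy.unravel_index)."""
    coords = []
    for dim in reversed(shape):
        index, rem = divmod(index, dim)
        coords.append(rem)
    return tuple(reversed(coords))


shape = (1, 48, 2048)   # shape of the `embedding` MLMultiArray
count = prod(shape)

print(count)                      # 98304, so valid flat indices are 0...98303
print(unravel(0, shape))          # (0, 0, 0)
print(unravel(count - 1, shape))  # (0, 47, 2047), the last element
```

So the input setup itself is consistent; as reported, the crash happens earlier, when the model is loaded.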
The converted ML program looks like this:

```
main[CoreML8](%embedding: (1, 48, 2048, fp16)(Tensor),
              %k_cache: (1, 16, 128, 128, fp16)(State),
              %v_cache: (1, 16, 128, 128, fp16)(State)) {
  block145() {
    %wk_weight_cast_fp16: (2048, 2048, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wk_weight_data_0, scale=%wk_weight_scale_0_to_fp16, name="wk_weight_cast_fp16")
    %wv_weight_cast_fp16: (2048, 2048, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wv_weight_data_0, scale=%wv_weight_scale_0_to_fp16, name="wv_weight_cast_fp16")
    %wo_weight_cast_fp16: (128, 128, fp16)(Tensor) = constexpr_blockwise_shift_scale(data=%wo_weight_data_0, scale=%wo_weight_scale_0_to_fp16, name="wo_weight_cast_fp16")
    %linear_0_cast_fp16: (1, 48, 2048, fp16)(Tensor) = linear(x=%embedding, weight=%wk_weight_cast_fp16, bias=%linear_0_bias_0_to_fp16, name="linear_0_cast_fp16")
    %linear_1_cast_fp16: (1, 48, 2048, fp16)(Tensor) = linear(x=%embedding, weight=%wv_weight_cast_fp16, bias=%linear_0_bias_0_to_fp16, name="linear_1_cast_fp16")
    %k_3_cast_fp16: (1, 48, 16, 128, fp16)(Tensor) = reshape(x=%linear_0_cast_fp16, shape=[1, 48, 16, 128], name="k_3_cast_fp16")
    %v_3_cast_fp16: (1, 48, 16, 128, fp16)(Tensor) = reshape(x=%linear_1_cast_fp16, shape=[1, 48, 16, 128], name="v_3_cast_fp16")
    %read_state_0: (1, 16, 128, 128, fp16)(Tensor) = read_state(input=%k_cache, name="read_state_0")
    %k_cast_fp16: (1, 16, 48, 128, fp16)(Tensor) = transpose(x=%k_3_cast_fp16, perm=[0, 2, 1, 3], name="transpose_1")
    %k_cache_internal_tensor_assign_1_cast_fp16: (1, 16, 128, 128, fp16)(Tensor) = slice_update(x=%read_state_0, update=%k_cast_fp16, begin=[0, 0, 0, 0], end=[0, 0, 48, 0], stride=[1, 1, 1, 1], begin_mask=[False, False, False, True], end_mask=[True, True, False, True], squeeze_mask=[False, False, False, False], name="k_cache_internal_tensor_assign_1_cast_fp16")
    %coreml_update_state_0: (1, 16, 128, 128, fp16)(Tensor) = coreml_update_state(state=%k_cache, value=%k_cache_internal_tensor_assign_1_cast_fp16, name="coreml_update_state_0")
    %read_state_1: (1, 16, 128, 128, fp16)(Tensor) = read_state(input=%v_cache, name="read_state_1")
    %v_cast_fp16: (1, 16, 48, 128, fp16)(Tensor) = transpose(x=%v_3_cast_fp16, perm=[0, 2, 1, 3], name="transpose_0")
    %v_cache_internal_tensor_assign_1_cast_fp16: (1, 16, 128, 128, fp16)(Tensor) = slice_update(x=%read_state_1, update=%v_cast_fp16, begin=[0, 0, 0, 0], end=[0, 0, 48, 0], stride=[1, 1, 1, 1], begin_mask=[False, False, False, True], end_mask=[True, True, False, True], squeeze_mask=[False, False, False, False], name="v_cache_internal_tensor_assign_1_cast_fp16")
    %coreml_update_state_1: (1, 16, 128, 128, fp16)(Tensor) = coreml_update_state(state=%v_cache, value=%v_cache_internal_tensor_assign_1_cast_fp16, name="coreml_update_state_1")
    %op: (1, 16, 128, 128, fp16)(Tensor) = linear(x=%coreml_update_state_0, weight=%wo_weight_cast_fp16, bias=%linear_2_bias_0_to_fp16, name="linear_2_cast_fp16")
  } -> (%op)
}
```
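To read the `slice_update` ops in the program above: in MIL, a `True` entry in `begin_mask` or `end_mask` means the corresponding `begin` or `end` value is ignored, and the slice starts at 0 or runs to the end of that axis. The stdlib-Python sketch below uses a hypothetical helper, `resolve_slice`, which assumes positive strides and no squeezing, to resolve the `k_cache` arguments. The written region comes out as `(1, 16, 48, 128)`, which is exactly the shape of `%k_cast_fp16`, so the op matches `self.k_cache[:, :, 0:seqlen] = k`.

```python
# Hypothetical helper: resolves MIL-style slice arguments to concrete bounds.
# Assumes positive strides and non-negative begin/end values.


def resolve_slice(shape, begin, end, stride, begin_mask, end_mask):
    """Return (start, stop, step) per axis; a True mask ignores begin/end."""
    bounds = []
    for dim, b, e, s, bm, em in zip(shape, begin, end, stride, begin_mask, end_mask):
        start = 0 if bm else b    # begin_mask: start from the beginning of the axis
        stop = dim if em else e   # end_mask: run to the end of the axis
        bounds.append((start, stop, s))
    return bounds


def region_shape(bounds):
    """Number of elements selected along each axis."""
    return tuple(len(range(start, stop, step)) for start, stop, step in bounds)


# Arguments copied from the k_cache slice_update in the MIL program above
cache_shape = (1, 16, 128, 128)
bounds = resolve_slice(
    cache_shape,
    begin=[0, 0, 0, 0],
    end=[0, 0, 48, 0],
    stride=[1, 1, 1, 1],
    begin_mask=[False, False, False, True],
    end_mask=[True, True, False, True],
)

print(bounds)                # [(0, 1, 1), (0, 16, 1), (0, 48, 1), (0, 128, 1)]
print(region_shape(bounds))  # (1, 16, 48, 128)
```

The `v_cache` update uses identical arguments, so it resolves to the same region.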

## System environment (please complete the following information):
 - coremltools version: 8.0b1
 - OS (e.g. MacOS version or Linux type): iOS 18 on an iPhone 15 Pro
 - Any other relevant version information (e.g. PyTorch or TensorFlow version): 

TobyRoseman commented 1 month ago

Are you able to get predictions from your model in Python?

seayoung1112 commented 1 month ago

> Are you able to get predictions from your model in Python?

I ran into some issues upgrading to macOS 15, which is required for running model inference in Python. I'll add the results once I have them.