apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Issue with `index_put` and/or `scatter_nd` ops #2040

Open Datamance opened 1 year ago

Datamance commented 1 year ago

Here is my script:

import torch
import transformers
import coremltools as ct
import numpy as np

print(
    "Versions:\n"
    f"torch: {torch.__version__}\n"
    f"transformers: {transformers.__version__}\n"
    f"coremltools: {ct.__version__}\n"
    f"numpy: {np.__version__}"
)

model = (
    transformers.AutoModelForSequenceClassification.from_pretrained(
        "yikuan8/Clinical-Longformer",
        num_labels=1,
        problem_type="regression",
        torch_dtype=torch.float32,
        torchscript=True,
        return_dict=False,  # Need this, or forward will bork
    )
    .to(device="mps", dtype=torch.float32)
    .eval()
)

# skip as_strided, unfold, etc. by pretending we're exporting to ONNX
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/longformer/modeling_longformer.py#L780
model.config.onnx_export = True

# mimic batched inputs
fake_inputs = torch.randint(50_000, (1, 4096)).to(device="mps")
fake_attn_mask = torch.ones_like(fake_inputs).to(dtype=int, device="mps")

fake_input_dict = {"input_ids": fake_inputs, "attention_mask": fake_attn_mask}

traced = torch.jit.trace(model, example_kwarg_inputs=fake_input_dict)

# Convenience aliases
register_torch_op = ct.converters.mil.frontend.torch.register_torch_op
_get_inputs = ct.converters.mil.frontend.torch.ops._get_inputs
_make_fill_op = ct.converters.mil.frontend.torch.ops._make_fill_op

@register_torch_op
def new_ones(context, node):
    # The difference between "new_full" and "full" is that the "new_full" is called from
    # an existing tensor: tensor.new_full(size, fill_value), while the "full" is called
    # from the torch API: torch.full(size, fill_value).
    # But they are basically doing the same thing.
    inputs = _get_inputs(context, node)
    size = inputs[1]
    result = _make_fill_op(size, 1, node.name)
    context.add(result)

# Can replace one of the shape elements with a ct.RangeDim(lower, upper) for variable size stuff.
mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(name="input_ids", shape=(1, 4096), dtype=np.int32),
        ct.TensorType(name="attention_mask", shape=(1, 4096), dtype=np.int32),
    ],
    minimum_deployment_target=ct.target.macOS14,
)

I think tracing is the right call here because my inputs will always be collated and padded to the same shape. So I'm not worried about the warnings that tracing emits. They mostly come from assert statements anyway, so they don't really affect control flow.

You can see above that I circumvented initial errors about the as_strided op not being implemented by lying to the config about exporting to ONNX. That got me a little farther, but then I hit another wall at new_ones. I tried to circumvent this yet again by registering another op like so:

@register_torch_op
def new_ones(context, node):
    # The difference between "new_full" and "full" is that the "new_full" is called from
    # an existing tensor: tensor.new_full(size, fill_value), while the "full" is called
    # from the torch API: torch.full(size, fill_value).
    # But they are basically doing the same thing.
    inputs = _get_inputs(context, node)
    size = inputs[1]
    result = _make_fill_op(size, 1, node.name)
    context.add(result)

which is basically a slightly modified copy of the built-in new_full converter. That seems to work (although I'd welcome feedback on whether this is the right approach), but then I get this error:

  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 519, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    add_op(context, node)
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 3647, in index_put
    result = mb.scatter_nd(data=x, indices=indices, updates=values, mode=mode, name=node.name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 182, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/builder.py", line 184, in _add_op
    new_op.type_value_inference()
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/operation.py", line 257, in type_value_inference
    output_types = self.type_inference()
                   ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS17/scatter_gather.py", line 228, in type_inference
    result = super().type_inference()
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py", line 546, in type_inference
    assert is_compatible_symbolic_vector(
AssertionError

I stepped through with a debugger, and the culprit seems to be index_put, which is lowered to scatter_nd. When the builder adds the scatter_nd op, it runs type inference and fails this assertion:

    def type_inference(self):
        assert self.indices.shape[-1] <= self.data.rank
        expected_updates_shape = (
            self.indices.shape[:-1] + self.data.shape[self.indices.shape[-1] :]
        )
        assert is_compatible_symbolic_vector(
            self.updates.shape, tuple(expected_updates_shape)
        )
        return self.data.sym_type
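For reference, here is a small NumPy sketch of what scatter_nd in "update" mode does and of the shape rule the assertion enforces (updates must have shape indices.shape[:-1] + data.shape[indices.shape[-1]:]). It is our own emulation for illustration, not coremltools code. The sizes are hypothetical, and the (N, 2) layout of indices, one (batch, position) pair per row, is an assumption about how index_put's two index vectors end up stacked.

```python
import numpy as np

def expected_updates_shape(data_shape, indices_shape):
    # Same rule as the type_inference above:
    # indices.shape[:-1] + data.shape[indices.shape[-1]:]
    k = indices_shape[-1]
    assert k <= len(data_shape)
    return tuple(indices_shape[:-1]) + tuple(data_shape[k:])

def scatter_nd_update(data, indices, updates):
    """Copy `data`, then overwrite the slices addressed by each row of `indices`."""
    expected = expected_updates_shape(data.shape, indices.shape)
    if updates.shape != expected:
        raise ValueError(f"updates shape {updates.shape} != expected {expected}")
    out = data.copy()
    k = indices.shape[-1]
    rows = indices.reshape(-1, k)
    slices = updates.reshape((-1,) + data.shape[k:])
    for idx, upd in zip(rows, slices):
        out[tuple(idx)] = upd
    return out

# Hypothetical sizes: (batch=1, max_num_global=1, heads=12, head_dim=64), N=1 index pair.
data = np.zeros((1, 1, 12, 64), dtype=np.float32)
indices = np.array([[0, 0]])
print(expected_updates_shape(data.shape, indices.shape))  # (1, 12, 64)
out = scatter_nd_update(data, indices, np.ones((1, 12, 64), dtype=np.float32))
```

Under this rule, an updates tensor of shape (12, 64) is rejected, because the leading N dimension contributed by indices.shape[:-1] is missing.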

Checking the shapes in the debugger shows that updates is missing the leading dimension:

self.updates.shape
Out[4]: (12, 64)
tuple(expected_updates_shape)
Out[5]: (is5, 12, 64)
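A simplified sketch of the shape-compatibility check helps show why these two shapes fail. It assumes that is_compatible_symbolic_vector requires equal rank and only compares dimension pairs where neither side is symbolic. This is our own re-implementation for illustration, not the coremltools source, and modeling symbols as strings such as "is5" is a hypothetical stand-in.

```python
def is_symbolic(dim):
    # Hypothetical stand-in: symbolic dims are modeled as strings like "is5".
    return isinstance(dim, str)

def shapes_compatible(a, b):
    """Same rank, and every pair of concrete dims must match.
    A symbolic dim is treated as matching anything."""
    a, b = tuple(a), tuple(b)
    if len(a) != len(b):
        return False
    return all(is_symbolic(x) or is_symbolic(y) or x == y for x, y in zip(a, b))

print(shapes_compatible((12, 64), ("is5", 12, 64)))     # False: rank 2 vs rank 3
print(shapes_compatible((1, 12, 64), ("is5", 12, 64)))  # True: the symbol matches the 1
```

If the real check works this way, the symbol is5 on its own would not fail the assertion; the rank mismatch between (12, 64) and (is5, 12, 64) would.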

Immediate question: did I do something wrong, or is this a bug?

Follow-up: how do you recommend I proceed here? For context, I would ideally like to do training-time quantization (probably Linear 8-bit) on this Longformer model. Will I keep running into unimplemented ops even if I get past this issue? Should I try something else entirely?

Thanks!

Datamance commented 1 year ago

OK, from the graph output, it looks like index_put is first called in this method, which is invoked by forward in LongformerSelfAttention:

    def _concat_with_global_key_attn_probs(
        self,
        key_vectors,
        query_vectors,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ):
        batch_size = key_vectors.shape[0]

        # create only global key vectors
        key_vectors_only_global = key_vectors.new_zeros(
            batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
        )

        key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]

in a conditional block that checks if the pass is for global attention:

        # compute local attention probs from global attention keys and contact over window dim
        if is_global_attn:
            # compute global attn indices required through out forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key

            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )

So somehow, it looks like key_vectors is the wrong shape. Or, more precisely, key_vectors[is_index_global_attn_nonzero] is short one dimension. I also ruled out batch size as the complicating factor by changing all the input shapes to (4, 4096), which is closer to what I'd use in real training.
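The following NumPy sketch shows what shapes these indexing expressions should have. It assumes, as in the upstream _get_global_attn_indices (not shown above), that the index tensors come from nonzero(as_tuple=True), and that only the first token of each sequence gets global attention. The sizes are hypothetical, with seq_len shrunk from 4096. NumPy's tuple-of-arrays indexing follows the same rules as PyTorch's here.

```python
import numpy as np

# Hypothetical small sizes; only seq_len is shrunk from the real 4096.
batch, seq_len, num_heads, head_dim = 1, 8, 12, 64
key_vectors = np.random.rand(batch, seq_len, num_heads, head_dim).astype(np.float32)

# Assumption: only the first token of each sequence gets global attention.
is_index_global_attn = np.zeros((batch, seq_len), dtype=bool)
is_index_global_attn[:, 0] = True

# np.nonzero returns a tuple of index arrays, like torch's nonzero(as_tuple=True).
num_global = is_index_global_attn.sum(axis=1)
max_num_global = int(num_global.max())
is_index_global_attn_nonzero = np.nonzero(is_index_global_attn)
is_local_index_global_attn = np.arange(max_num_global) < num_global[:, None]
is_local_index_global_attn_nonzero = np.nonzero(is_local_index_global_attn)

key_vectors_only_global = np.zeros(
    (batch, max_num_global, num_heads, head_dim), dtype=np.float32
)
updates = key_vectors[is_index_global_attn_nonzero]
print(updates.shape)  # (1, 12, 64): N selected tokens, then heads and head_dim

# This assignment is what gets traced as index_put.
key_vectors_only_global[is_local_index_global_attn_nonzero] = updates
```

The leading N is the number of True entries in the mask. It depends on the mask's values, not only its shape, which is presumably why the converter can only represent it as a symbol such as is5.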

I will have to dig in more to see why the mismatch happens, but again I am all ears for different approaches/solutions to this issue.

YifanShenSZ commented 1 year ago

Hey @Datamance, the error you got

  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS17/scatter_gather.py", line 228, in type_inference
    result = super().type_inference()
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py", line 546, in type_inference
    assert is_compatible_symbolic_vector(
AssertionError

is probably caused by a dynamic shape somewhere in your model. Three observations support my guess:

  1. new_ones creates a new tensor filled with 1s, given the shape of another tensor. It is possible that this other tensor has a dynamic shape.
  2. You do have Out[5]: (is5, 12, 64), where is5 is integer symbol number 5, i.e. a dimension whose size is not known at conversion time.
  3. The batch size is a potential source of dynamic shapes.

I would suggest trying a fixed-shape model first, to rule out noise from symbolic shapes and make sure things convert.