apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Support for scatter_max #1766

Open zobertke opened 1 year ago

zobertke commented 1 year ago

❓Question

Hi! I am in the middle of converting a graph neural network that uses a message-passing layer at one point, and while converting the torch model I get an unsupported-operation error for the torch.scatter_max operation. I saw that you have a scatter_gather operation with mode max here https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.scatter_gather. Is this the same operation as https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html? If yes, how can I reroute it through a custom operation? Thanks a lot 🙏
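
For comparison with the MIL ops, here is a minimal NumPy reference of what torch_scatter's scatter_max computes along dim 0: output row index[i] receives the elementwise maximum of every src row mapped to it, plus an argmax recording which src row won. The function name scatter_max_reference is hypothetical. How empty output rows are filled has varied between torch_scatter versions; writing 0 and using E (the number of src rows) as the argmax placeholder is my understanding of the 2.x behavior, not something confirmed in this thread.

```python
import numpy as np


def scatter_max_reference(src, index, dim_size):
    """NumPy reference for torch_scatter.scatter_max(src, index, dim=0, dim_size=...).

    src:   (E, F) float array, one row per message/edge
    index: (E,)   int array, output row for each src row
    Returns (out, argmax), both of shape (dim_size, F).
    """
    num_rows, num_feats = src.shape
    out = np.full((dim_size, num_feats), -np.inf, dtype=src.dtype)
    # E (one past the last src row) marks "no src row landed here"
    argmax = np.full((dim_size, num_feats), num_rows, dtype=np.int64)
    for i in range(num_rows):
        row = index[i]
        better = src[i] > out[row]  # elementwise over F; ties keep the earlier row
        out[row] = np.where(better, src[i], out[row])
        argmax[row] = np.where(better, i, argmax[row])
    # Assumed convention for empty output rows: value 0
    out[argmax == num_rows] = 0
    return out, argmax
```

For example, with index = [0, 0, 2] and dim_size = 3, src rows 0 and 1 are reduced into output row 0, output row 2 copies src row 2, and output row 1 stays empty. Note that the output has dim_size rows regardless of how many rows src has.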

TobyRoseman commented 1 year ago

See our User Guide Section for Using Composite Ops with PyTorch Conversion.

I'm going to close this issue, but let us know if you have any more questions.

zobertke commented 1 year ago

Thanks for the link. I am aware of the composite ops; that is how I am trying to work around this issue. My question was more: which MIL operation supports the torch scatter_max operation mentioned above? You have some scatter operations, like scatter_along_axis and scatter_nd, but those take a separate tensor of updates, while scatter_max takes only the input data and no update data.
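
One way to map scatter_max onto a data/updates-style scatter: src itself plays the role of the updates, and data is a fresh buffer of shape (dim_size, F) filled with a very small value, which the max-mode scatter then overwrites. The sketch below assumes a row-wise scatter along axis 0 whose "max" mode combines repeated indices with a running maximum; that is my reading of the MIL scatter docs, not verified here. scatter_rows and scatter_max_values are hypothetical NumPy stand-ins, not coremltools functions.

```python
import numpy as np


def scatter_rows(data, indices, updates, mode="max"):
    """Row-wise scatter along axis 0 (NumPy stand-in, not coremltools code).

    data:    (N, F) buffer that receives the updates
    indices: (K,)   target row for each update row
    updates: (K, F) one row per index
    """
    out = data.copy()
    combine = {"max": np.maximum, "add": np.add}[mode]
    # ufunc.at is unbuffered, so repeated indices accumulate correctly
    combine.at(out, indices, updates)
    return out


def scatter_max_values(src, index, dim_size):
    # The buffer, not src, is the "data"; src becomes the "updates"
    buffer = np.full((dim_size, src.shape[1]), -np.inf)
    out = scatter_rows(buffer, index, src, mode="max")
    out[np.isneginf(out)] = 0  # empty rows (assumed torch_scatter 2.x convention)
    return out
```

In this formulation nothing has to be reshaped: the buffer's shape comes from dim_size and the feature size F, and src is consumed as-is.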

zobertke commented 1 year ago

I tried the following, but got an error saying it can't add the shape as a const (I assume because it has a flexible shape):

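from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def scatter_max(context, node):
    # Inputs of torch_scatter.scatter_max: src, index, dim, out, dim_size
    inputs = _get_inputs(context, node, expected=5)
    src = inputs[0]
    indices = inputs[1]
    axis = inputs[2].val
    # This reshape is where conversion fails: the target shape can't be
    # added as a const, presumably because src has a flexible shape
    updates = mb.reshape(
        x=src,
        shape=[src.shape[1], inputs[4].val]
    )
    result = mb.scatter_along_axis(data=src, updates=updates, indices=indices, axis=axis, mode='max', name=node.name)
    context.add(result, torch_name=node.name)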

Note: src has a dynamic shape (is66, 8), and it seems reshape can't work with it right now. scatter_max is needed because the model uses a message-passing layer from pytorch_geometric.

TobyRoseman commented 1 year ago

Based on the error, it looks like the input to your layer has a flexible shape. Before trying to get it to work with flexible shapes, I'd recommend getting it to work with a simple isolated example (i.e. a unit test) that has a fixed shape.

Take a look at the torch op unit tests; there are many examples of unit tests for a single layer.

zobertke commented 1 year ago

The issue is that the incoming data already has different shapes: src.shape is (is66, 8, fp32), i.e. a dynamic number of vectors with 8 elements each, and indices.shape is (is65, int32). The scatter function needs rank(src) == rank(indices), so this fails immediately. I believe I need a reshape, but reshape doesn't like dynamic inputs.
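
To illustrate the rank requirement: a row-wise scatter takes one index per update row (rank-1 indices), whereas an along-axis scatter expects indices with the same shape as the updates. Expanding the (E,) index to (E, F), which is what an unsqueeze followed by an expand does in PyTorch, makes the two equivalent. Both functions below are hypothetical NumPy stand-ins whose semantics I am assuming to mirror MIL's scatter and scatter_along_axis with axis=0; they are not the coremltools implementations.

```python
import numpy as np


def scatter_rowwise_max(data, index, updates):
    # Row-wise scatter: one index per update row, so index has rank 1
    out = data.copy()
    np.maximum.at(out, index, updates)
    return out


def scatter_along_axis0_max(data, indices, updates):
    # Along-axis scatter: indices must have the same shape (and rank) as updates
    if indices.shape != updates.shape:
        raise ValueError(f"rank/shape mismatch: {indices.shape} vs {updates.shape}")
    out = data.copy()
    cols = np.broadcast_to(np.arange(updates.shape[1]), indices.shape)
    np.maximum.at(out, (indices, cols), updates)
    return out


src = np.array([[1., 5.], [3., 2.], [0., 4.]])  # (E, F) = (3, 2)
index = np.array([0, 0, 2])                     # (E,): rank 1, rejected by along-axis
buffer = np.full((3, 2), -np.inf)               # (dim_size, F)

# unsqueeze + expand: (E,) -> (E, 1) -> (E, F)
index_expanded = np.broadcast_to(index[:, None], src.shape)
row_wise = scatter_rowwise_max(buffer, index, src)
along_axis = scatter_along_axis0_max(buffer, index_expanded, src)
```

Both calls produce the same result; only the along-axis variant needs the expanded index, which is why the rank mismatch appears there and not with a row-wise scatter.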

TobyRoseman commented 1 year ago

Based on the documentation for the MIL reshape op, I think reshape supports one dynamic input:

Symbols: All but one symbol in shape must be present in x.shape. The new symbol that is not
present in x.shape represent a dimension such that the total size remains constant.

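NumPy's reshape has a comparable rule for -1 entries, which may make the documented behavior easier to picture: at most one dimension may be left unspecified, and it is inferred so that the total number of elements stays the same. This is only an analogy in NumPy, not MIL's symbolic shapes; 66 stands in for the is66 symbol.

```python
import numpy as np

# 66 stands in for the flexible leading dimension (is66) of src
src = np.arange(66 * 8, dtype=np.float32).reshape(66, 8)

# One unknown dimension is fine: it is inferred as 66 * 8 // (4 * 2) = 66
split_features = src.reshape(-1, 4, 2)

# Two unknown dimensions are ambiguous, so NumPy refuses
try:
    src.reshape(-1, -1)
    two_unknowns_allowed = True
except ValueError:
    two_unknowns_allowed = False
```

Because reshape only rearranges existing elements, a target shape built from F and dim_size, as in the earlier attempt, is valid only if the element counts happen to agree.
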
zobertke commented 1 year ago

I finally got it working and got past this step. The implementation is:

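from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def scatter_max(context, node):
    # Inputs of torch_scatter.scatter_max: src, index, dim, out, dim_size
    inputs = _get_inputs(context, node, expected=5)
    data = inputs[0]
    indices = inputs[1]
    axis = inputs[2].val
    update_dim = inputs[4].val  # dim_size
    # Constant updates of shape (dim_size, F); mb.fill's value defaults to 0.0
    updates = mb.fill(shape=(update_dim, data.shape[1]))
    # mb.scatter, unlike scatter_along_axis, doesn't need indices to match data's rank
    context.add(mb.scatter(data=data, updates=updates, indices=indices, axis=axis, mode='max', name=node.name))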

But soon I stumbled into another, much bigger issue. To give you broader context, what I am trying to convert is the softmax operation in PyTorch Geometric (https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/utils/softmax.html). The relevant lines are:

N = maybe_num_nodes(index, num_nodes)
src_max = scatter(src, index, dim, dim_size=N, reduce='max')
src_max = src_max.index_select(dim, index)
out = (src - src_max).exp()
out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')
out_sum = out_sum.index_select(dim, index)
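
For context, these lines implement a per-segment softmax: entries of src that share an index value are normalized together, with the segment maximum subtracted first for numerical stability. The NumPy sketch below mirrors the excerpt for dim=0. The final division by out_sum (with a small epsilon) is not part of the excerpt and is my assumption about how the function finishes; the name segment_softmax is hypothetical.

```python
import numpy as np


def segment_softmax(src, index, num_nodes):
    """Softmax over groups of rows of src that share the same index (dim=0)."""
    feats = src.shape[1:]
    # scatter(..., reduce='max') followed by index_select: per-row segment max
    src_max = np.full((num_nodes, *feats), -np.inf)
    np.maximum.at(src_max, index, src)
    out = np.exp(src - src_max[index])
    # scatter(..., reduce='sum') followed by index_select: per-row segment sum
    out_sum = np.zeros((num_nodes, *feats))
    np.add.at(out_sum, index, out)
    return out / (out_sum[index] + 1e-16)
```

Because both index_select calls only gather rows that actually occur in index, the value stored for an empty segment never reaches the result.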

In my case, the converter roughly translates this to:

%src_max.1, %303 = scatter_max[](%src.1, %index.5, %100, %98, %90)
%src_max.3 = index_select[](%src_max.1, %100, %index.5)
%305 = sub[](%src.1, %src_max.3, %102)
%other.1 = exp[](%305)
%src.3 = unsqueeze[](%index.5, %94)
%308 = size[](%other.1, %100)
%309 = size[](%other.1, %102)
%310 = listconstruct[](%308, %309)
%index.7 = expand[](%src.3, %310, %97)
%312 = size[](%other.1, %102)
%313 = listconstruct[](%291, %312)
%out.75 = zeros[](%313, %99, %98, %105, %97)
%out_sum.1 = scatter_add_[](%out.75, %100, %index.7, %other.1)

I don't see why these ops come into the picture:

%src.3 = unsqueeze[](%index.5, %94)
%308 = size[](%other.1, %100)
%309 = size[](%other.1, %102)
%310 = listconstruct[](%308, %309)
%index.7 = expand[](%src.3, %310, %97)
%312 = size[](%other.1, %102)
%313 = listconstruct[](%291, %312)
%out.75 = zeros[](%313, %99, %98, %105, %97)

They definitely end up blowing up the input to scatter_add, which then fails. I have checked that scatter_max has the correct output shape. I would have expected something like:

%src_max.1, %303 = scatter_max[](%src.1, %index.5, %100, %98, %90)
%src_max.3 = index_select[](%src_max.1, %100, %index.5)
%305 = sub[](%src.1, %src_max.3, %102)
%other.1 = exp[](%305)
%out_sum.1 = scatter_add_[](%other.1, %100, %index.5, %out.75)

Should I override scatter_add too, with my own implementation? Maybe you can shed some light on what is happening with all the extra operations. Thank you!

zobertke commented 1 year ago

Figured it out: there is a broadcast op before scatter_add is actually called. Now my issue boils down to scatter_add -> scatter_along_axis being called with data of shape [12, 8] and indices of shape (is68, is69), and there is a check here https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py#L430 where a dynamic shape is compared to a constant one. So the shape gets lost somewhere along the way, possibly in scatter_max.
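
To my reading of torch_scatter's source, its Python scatter_sum/scatter_add first calls a broadcast helper that unsqueezes index and expands it to src's shape, then allocates zeros of shape (dim_size, F) and calls Tensor.scatter_add_ along dim. That lines up with the extra unsqueeze, size, listconstruct, expand and zeros ops in the trace above. A NumPy re-creation for dim=0, with each step commented with the traced op it appears to correspond to (the helper names are hypothetical):

```python
import numpy as np


def broadcast_index(index, src):
    # %src.3 = unsqueeze(index); %index.7 = expand(%src.3, [E, F])
    return np.broadcast_to(index[:, None], src.shape)


def scatter_add_dim0(out, index, src):
    # Tensor.scatter_add_(0, index, src): out[index[i, j], j] += src[i, j]
    cols = np.broadcast_to(np.arange(src.shape[1]), src.shape)
    np.add.at(out, (index, cols), src)
    return out


def scatter_sum(src, index, dim_size):
    index_b = broadcast_index(index, src)                # (E,) -> (E, F)
    out = np.zeros((dim_size, src.shape[1]), src.dtype)  # %out.75 = zeros([N, F])
    return scatter_add_dim0(out, index_b, src)           # %out_sum.1 = scatter_add_
```

In this decomposition the expanded index has the same shape as src, which fits the along-axis scatter the converter routes it to, while the output's row count comes from dim_size rather than from src.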

TobyRoseman commented 1 year ago

I'm a bit confused. What is the current state/problem here?

Can you share a minimal example to reproduce your problem? Are you calling torch.jit.trace on your model prior to conversion?

zobertke commented 1 year ago

Hi! I am still stuck. Yes, I am calling jit.trace on my model. The model is a bigger one, but at one point it uses a graph attention layer from PyTorch Geometric (https://arxiv.org/abs/2105.14491):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GATConv


class GAT(torch.nn.Module):
    def __init__(self, hidden_dim, dropout, model_type, batch_norm=False):
        super().__init__()
        # 8 heads of size hidden_dim // 8; heads are concatenated (the default),
        # so each conv maps hidden_dim -> hidden_dim when hidden_dim % 8 == 0
        if model_type == "gat":
            self.conv1 = GATConv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
            self.conv2 = GATConv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
        elif model_type == "gatv2":
            self.conv1 = GATv2Conv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
            self.conv2 = GATv2Conv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
        else:
            raise ValueError(f"Unknown model_type argument: {model_type!r}")
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edges_idx):
        # x: (num_nodes, hidden_dim) node features; edges_idx: (2, num_edges)
        x = self.dropout(x)
        x = self.conv1(x, edges_idx)
        if self.batch_norm:
            x = self.bn(x)
        x = F.elu(x)
        x = self.dropout(x)
        x = self.conv2(x, edges_idx)
        return x

The implementation of this layer contains a softmax (https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/utils/softmax.html), which uses a scatter_max followed by a scatter_add operation. Core ML's PyTorch ops currently don't support scatter_max, hence the custom operation above. In it I call mb.scatter (not mb.scatter_along_axis, because I can't get indices and data to have the same tensor shape). This seems to work, but then scatter_add from https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py#L8 is traced down to a broadcast, and only then is scatter_add called, which in turn fails because the input tensors have incompatible shapes. Because this is a graph neural network, I need to work with flexible shapes.

zobertke commented 1 year ago

Finally got it working: I am able to convert the model. At inference time, though, I get a generic error like this:

Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  42%|██████▎        | 213/506 [00:00<00:00, 788.91 ops/s]
Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops:  68%|██████████▎    | 346/506 [00:00<00:00, 458.44 ops/s]
Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|██████████████▉| 504/506 [00:00<00:00, 517.36 ops/s]
Running MIL Common passes:   0%|                                      | 0/39 [00:00<?, ? passes/s]
/Users/csongorszabo/opt/miniconda3/envs/graph_env/lib/python3.9/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:135: UserWarning: Output, '608', of the source model, has been renamed to 'var_608' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|█████████████████████████████| 39/39 [00:00<00:00, 52.10 passes/s]
Running MIL FP16ComputePrecision pass: 100%|███████████████████| 1/1 [00:00<00:00,  2.58 passes/s]
Running MIL Clean up passes: 100%|███████████████████████████| 11/11 [00:01<00:00, 10.25 passes/s]
Backend MacOSX is interactive backend. Turning interactive mode on.
    return self.__proxy__.predict(data)
RuntimeError: {
    NSLocalizedDescription = "Error computing NN outputs.";
}

Any idea how to debug this further? Do I need to bisect the model somehow, or can I get more meaningful logs somewhere?

TobyRoseman commented 1 year ago

Bisecting the model is a good idea. However, there are a couple of easier things I recommend you try first.

1 - When calling ct.convert, try using convert_to="mlprogram".

2 - Also when calling ct.convert, try specifying the output shape using the outputs parameter. For example, if you have a single output with shape (1, 45), do:

outputs=[ct.TensorType(shape=(1, 45))],
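
If neither suggestion helps, the bisection itself can be scripted. Below is a minimal, framework-agnostic sketch, assuming you can write a predicate that converts and runs the model truncated after its first k layers; fails_after and first_failing_layer are hypothetical names, and how to truncate this particular model is left open.

```python
from typing import Callable


def first_failing_layer(num_layers: int, fails_after: Callable[[int], bool]) -> int:
    """Return the smallest k such that the model truncated after k layers fails.

    Assumes monotonicity: if a prefix of k layers fails, every longer prefix
    fails too. Returns -1 if even the full model succeeds.
    """
    if not fails_after(num_layers):
        return -1
    lo, hi = 1, num_layers  # invariant: the prefix of length hi fails
    while lo < hi:
        mid = (lo + hi) // 2
        if fails_after(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo
```

With 20 layers this needs at most six predicate calls instead of twenty. The search relies on the monotonicity assumption, so a prefix that fails only because of how it was truncated can mislead it.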