zobertke opened this issue 1 year ago
See our User Guide Section for Using Composite Ops with PyTorch Conversion.
I'm going to close this issue, but let us know if you have any more questions.
Thanks for the link. I am aware of composite ops; that is how I am trying to work around this issue. My question was more like: which MIL operation supports the scatter_max torch operation mentioned above? You have some scatter operations, like scatter_along_axis and scatter_nd, but those take a separate input tensor for the updates, while scatter_max has only input data and no update data.
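One way to read the mapping, as a minimal sketch that assumes the documented axis=0 semantics of MIL `scatter`: torch_scatter's `scatter_max` needs no separate updates tensor because it allocates its output itself. In MIL terms, `src` would play the role of `updates`, `index` the role of `indices`, and `data` would be a freshly allocated (dim_size, F) buffer. The pure-Python helpers below (`mil_style_scatter` and `scatter_max_reference` are hypothetical names, not coremltools or torch_scatter APIs) illustrate that equivalence on nested lists. Reporting rows that receive no updates as 0 is an assumption about torch_scatter's behaviour.

```python
from typing import List

Matrix = List[List[float]]


def mil_style_scatter(data: Matrix, indices: List[int], updates: Matrix, mode: str = "max") -> Matrix:
    """Hypothetical reference of a row-wise scatter along axis 0.

    For each i, row indices[i] of the output is combined with updates[i]
    using `mode` ("add" or "max"); `data` only supplies the initial values.
    """
    if len(indices) != len(updates):
        raise ValueError("updates must have one row per index")
    out = [row[:] for row in data]  # copy so the input stays untouched
    for i, target in enumerate(indices):
        for j, value in enumerate(updates[i]):
            if mode == "add":
                out[target][j] += value
            elif mode == "max":
                out[target][j] = max(out[target][j], value)
            else:
                raise ValueError(f"unsupported mode: {mode}")
    return out


def scatter_max_reference(src: Matrix, index: List[int], dim_size: int) -> Matrix:
    """Hypothetical reference for scatter_max(src, index, dim=0, dim_size=N).

    src becomes `updates`, index becomes `indices`, and `data` is a freshly
    allocated (dim_size, F) buffer initialised to -inf.
    """
    num_features = len(src[0])
    data = [[float("-inf")] * num_features for _ in range(dim_size)]
    out = mil_style_scatter(data, index, src, mode="max")
    # Assumption: rows that received no updates are reported as 0.
    return [[0.0 if v == float("-inf") else v for v in row] for row in out]


src = [[1.0, 5.0], [3.0, 2.0], [4.0, 0.0]]
index = [0, 0, 2]
print(scatter_max_reference(src, index, dim_size=3))
# [[3.0, 5.0], [0.0, 0.0], [4.0, 0.0]]
```

Under this reading, a composite op would pass a filled (dim_size, F) tensor as `data` and the original `src` as `updates`; that is an inference from the documented semantics, not something verified in this thread.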
I tried the following, but got an error saying that the shape can't be added as a const (I assume because it is a flexible shape):
@register_torch_op
def scatter_max(context, node):
    inputs = _get_inputs(context, node, expected=5)
    src = inputs[0]
    indices = inputs[1]
    axis = inputs[2].val
    updates = mb.reshape(
        x=src,
        shape=[src.shape[1], inputs[4].val]
    )
    result = mb.scatter_along_axis(data=src, updates=updates, indices=indices, axis=axis, mode='max', name=node.name)
    context.add(result, torch_name=node.name)
Note: src has a dynamic shape (is66, 8), and it seems reshape can't handle it right now. scatter_max is needed because the model uses a message-passing layer from pytorch_geometric.
Based on the error, it looks like the input to your layer has a flexible shape. Before trying to get it to work with flexible shapes, I'd recommend getting it to work with a simple isolated example (i.e. a unit test) that has a fixed shape.
Take a look at the torch op unit tests; there are many examples of unit tests for a single layer.
The issue is that the incoming tensors already have different shapes: src.shape is (is66, 8, fp32), i.e. a dynamic number of rows with 8 elements each, and indices.shape is (is65, int32), i.e. the indices. The scatter function requires rank(src) == rank(indices), so this fails immediately. I believe a reshape is needed, but reshape doesn't handle dynamic inputs.
Based on the documentation for the MIL reshape op, I think reshape supports one dynamic input:
Symbols: All but one symbol in shape must be present in x.shape. The new symbol that is not present in x.shape represent a dimension such that the total size remains constant.
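The "total size remains constant" rule means the one new dimension is determined by dividing the element count of the input by the product of the known target dimensions. A minimal sketch of that arithmetic, with concrete sizes standing in for the symbols; the helper `infer_missing_dim` and the `None` placeholder are hypothetical and not part of the MIL API:

```python
from math import prod
from typing import Optional, Sequence, Tuple


def infer_missing_dim(in_shape: Sequence[int], target: Sequence[Optional[int]]) -> Tuple[int, ...]:
    """Fill the single unknown entry (None) of `target` so that the total
    number of elements matches `in_shape`. Hypothetical helper for illustration."""
    unknown = [i for i, d in enumerate(target) if d is None]
    if len(unknown) != 1:
        raise ValueError("exactly one dimension may be unknown")
    known = prod(d for d in target if d is not None)
    total = prod(in_shape)
    if known == 0 or total % known != 0:
        raise ValueError(f"cannot reshape {tuple(in_shape)} into {tuple(target)}")
    result = list(target)
    result[unknown[0]] = total // known
    return tuple(result)


# For example, if the symbolic first dimension resolves to 66 at runtime:
print(infer_missing_dim((66, 8), (None, 4)))  # (132, 4)
```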
I was finally able to get it working and get past this step. The implementation is:
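from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def scatter_max(context, node):
    # torch_scatter::scatter_max(src, index, dim, out, dim_size)
    inputs = _get_inputs(context, node, expected=5)
    data = inputs[0]
    indices = inputs[1]
    axis = inputs[2].val
    update_dim = inputs[4].val  # dim_size: number of rows in the output
    # Zero-filled tensor of shape (dim_size, num_features)
    updates = mb.fill(shape=(update_dim, data.shape[1]))
    context.add(mb.scatter(data=data, updates=updates, indices=indices, axis=axis, mode='max', name=node.name))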
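As we read the MIL `scatter` documentation, `indices` is 1-D and `updates` must have shape `data.shape[:axis] + [len(indices)] + data.shape[axis+1:]`. The hypothetical pure-Python check below (not a coremltools function) makes that contract concrete with example sizes standing in for the symbolic ones: E rows in src, dim_size N, and 8 features.

```python
from typing import Sequence, Tuple


def expected_updates_shape(data_shape: Sequence[int], num_indices: int, axis: int = 0) -> Tuple[int, ...]:
    """Shape `updates` must have for a scatter of 1-D indices along a non-negative `axis`."""
    return tuple(data_shape[:axis]) + (num_indices,) + tuple(data_shape[axis + 1:])


def scatter_shapes_ok(data_shape, indices_shape, updates_shape, axis=0) -> bool:
    """Return True when the three shapes satisfy the contract described above."""
    if len(indices_shape) != 1:
        return False  # this op expects 1-D indices
    return tuple(updates_shape) == expected_updates_shape(data_shape, indices_shape[0], axis)


E, N = 66, 12  # example sizes standing in for the symbolic dimensions
print(scatter_shapes_ok((N, 8), (E,), (E, 8)))  # True: data=(dim_size, F), updates=src
print(scatter_shapes_ok((E, 8), (E,), (N, 8)))  # False: data=src, updates=(dim_size, F)
```

With concrete sizes, the data=src arrangement used in the implementation above only satisfies this contract when dim_size equals the number of rows of src, so it may be worth checking against it if later shape errors appear.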
But soon enough I stumbled into another, much bigger issue. To give you broader context, what I am trying to convert is the softmax operation in PyTorch Geometric: https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/utils/softmax.html. The relevant lines are:
N = maybe_num_nodes(index, num_nodes)
src_max = scatter(src, index, dim, dim_size=N, reduce='max')
src_max = src_max.index_select(dim, index)
out = (src - src_max).exp()
out_sum = scatter(out, index, dim, dim_size=N, reduce='sum')
out_sum = out_sum.index_select(dim, index)
In my case, the converter translates these lines roughly to:
%src_max.1, %303 = scatter_max[](%src.1, %index.5, %100, %98, %90)
%src_max.3 = index_select[](%src_max.1, %100, %index.5)
%305 = sub[](%src.1, %src_max.3, %102)
%other.1 = exp[](%305)
%src.3 = unsqueeze[](%index.5, %94)
%308 = size[](%other.1, %100)
%309 = size[](%other.1, %102)
%310 = listconstruct[](%308, %309)
%index.7 = expand[](%src.3, %310, %97)
%312 = size[](%other.1, %102)
%313 = listconstruct[](%291, %312)
%out.75 = zeros[](%313, %99, %98, %105, %97)
%out_sum.1 = scatter_add_[](%out.75, %100, %index.7, %other.1)
I don't see why these ops come into the picture:
%src.3 = unsqueeze[](%index.5, %94)
%308 = size[](%other.1, %100)
%309 = size[](%other.1, %102)
%310 = listconstruct[](%308, %309)
%index.7 = expand[](%src.3, %310, %97)
%312 = size[](%other.1, %102)
%313 = listconstruct[](%291, %312)
%out.75 = zeros[](%313, %99, %98, %105, %97)
but they definitely end up blowing up the input to scatter_add, which then fails. I have checked that scatter_max produces the correct output shape. I would have expected something like:
%src_max.1, %303 = scatter_max[](%src.1, %index.5, %100, %98, %90)
%src_max.3 = index_select[](%src_max.1, %100, %index.5)
%305 = sub[](%src.1, %src_max.3, %102)
%other.1 = exp[](%305)
%out_sum.1 = scatter_add_[](%other.1, %100, %index.5, %out.75)
Should I override scatter_add with a custom implementation too? Maybe you can shed some light on what is happening with all the extra operations. Thank you!
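For checking a converted model against expectations, the PyTorch Geometric softmax lines quoted above compute a per-group softmax: each entry is normalized only against the entries that share its index, with the group maximum subtracted first for numerical stability. A minimal pure-Python sketch for a 1-D `src` (the name `group_softmax` is hypothetical; the real function works on tensors along an arbitrary dim):

```python
import math
from typing import Dict, List


def group_softmax(src: List[float], index: List[int]) -> List[float]:
    """Softmax of src computed separately within each group of equal index."""
    # src_max = scatter(src, index, reduce='max'), then gathered back per entry
    group_max: Dict[int, float] = {}
    for value, g in zip(src, index):
        group_max[g] = max(group_max.get(g, value), value)
    # out = (src - src_max).exp()
    out = [math.exp(value - group_max[g]) for value, g in zip(src, index)]
    # out_sum = scatter(out, index, reduce='sum'), then gathered back per entry
    group_sum: Dict[int, float] = {}
    for value, g in zip(out, index):
        group_sum[g] = group_sum.get(g, 0.0) + value
    return [value / group_sum[g] for value, g in zip(out, index)]


probs = group_softmax([1.0, 2.0, 3.0], [0, 0, 1])
print([round(p, 4) for p in probs])  # [0.2689, 0.7311, 1.0]
```

The dictionary lookups play the role of `index_select`: every entry reads back the max and sum of its own group.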
Figured it out: there is a broadcast op before the actual scatter_add call. My issue now boils down to scatter_add (lowered to scatter_along_axis) being called with data of shape [12, 8] and indices of shape (is68, is69), and there is a check here https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py#L430 where a dynamic shape is compared to a constant one. So the shape information gets lost somewhere along the way, possibly in scatter_max.
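The extra unsqueeze/size/listconstruct/expand/zeros ops match how torch_scatter's sum reduction is written in Python: it broadcasts `index` to the shape of `src`, allocates a zero output with `dim_size` rows, and calls `scatter_add_`. A pure-Python sketch of those three steps for dim=0 on a 2-D `src` (helper names are hypothetical) shows why the indices reaching scatter_along_axis have the same rank and shape as the updates, (E, F) rather than (E,):

```python
from typing import List

Matrix = List[List[float]]


def broadcast_index(index: List[int], num_features: int) -> List[List[int]]:
    """index.unsqueeze(-1).expand(E, F): repeat each index across the feature axis."""
    return [[i] * num_features for i in index]


def scatter_sum_reference(src: Matrix, index: List[int], dim_size: int) -> Matrix:
    """Plain-list mirror of a dim=0 scatter sum: broadcast, zeros, scatter_add_."""
    num_features = len(src[0])
    expanded = broadcast_index(index, num_features)        # shape (E, F)
    out = [[0.0] * num_features for _ in range(dim_size)]  # zeros(dim_size, F)
    # scatter_add_ along dim 0: out[expanded[i][j]][j] += src[i][j]
    for i, row in enumerate(src):
        for j, value in enumerate(row):
            out[expanded[i][j]][j] += value
    return out


src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(broadcast_index([0, 0, 1], 2))             # [[0, 0], [0, 0], [1, 1]]
print(scatter_sum_reference(src, [0, 0, 1], 2))  # [[4.0, 6.0], [5.0, 6.0]]
```

In the traced graph, the expand target is built from size() calls on the exp() output, so the expanded index's shape is tied to the updates' runtime shape rather than to constants.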
I'm a bit confused. What is the current state/problem here?
Can you share a minimal example to reproduce your problem? Are you calling torch.jit.trace on your model prior to conversion?
Hi! I am still stuck. Yes, I am calling jit.trace on my model. The model is a bigger one, but at one point it uses a graph attention layer from PyTorch Geometric (https://arxiv.org/abs/2105.14491):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GATConv

class GAT(torch.nn.Module):
    def __init__(self, hidden_dim, dropout, model_type, batch_norm=False):
        super(GAT, self).__init__()
        # 8 heads of size hidden_dim // 8 are concatenated back to hidden_dim
        if model_type == "gat":
            self.conv1 = GATConv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
            self.conv2 = GATConv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
        elif model_type == "gatv2":
            self.conv1 = GATv2Conv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
            self.conv2 = GATv2Conv(hidden_dim, hidden_dim // 8, heads=8, dropout=dropout)
        else:
            raise Exception("Unknown model_type argument")
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edges_idx):
        x = self.dropout(x)
        x = self.conv1(x, edges_idx)
        if self.batch_norm:
            x = self.bn(x)
        x = F.elu(x)
        x = self.dropout(x)
        x = self.conv2(x, edges_idx)
        return x
The implementation of this layer contains a softmax (https://pytorch-geometric.readthedocs.io/en/2.2.0/_modules/torch_geometric/utils/softmax.html), which uses a scatter_max followed by a scatter_add operation. Core ML's PyTorch ops currently have no support for scatter_max, hence the custom operation above. There I am basically calling mb.scatter (not mb.scatter_along_axis, as I can't get the indices and data to have the same tensor shape). This seems to work, but then scatter_add from https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py#L8 is traced down to a broadcast, and only then is scatter_add called, which in turn fails because the input tensors have incompatible shapes. Because this is a graph neural network, I need to work with flexible shapes.
Finally got it working: I am able to convert the model. At inference time, though, I get a generic error like this:
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops: 42%|███████ | 213/506 [00:00<00:00, 788.91 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 68%|███████████ | 346/506 [00:00<00:00, 458.44 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|███████████████| 504/506 [00:00<00:00, 517.36 ops/s]
Running MIL Common passes: 0%| | 0/39 [00:00<?, ? passes/s]/Users/csongorszabo/opt/miniconda3/envs/graph_env/lib/python3.9/site-packages/coremltools/converters/mil/mil/passes/name_sanitization_utils.py:135: UserWarning: Output, '608', of the source model, has been renamed to 'var_608' in the Core ML model.
warnings.warn(msg.format(var.name, new_name))
Running MIL Common passes: 100%|█████████████████████████████| 39/39 [00:00<00:00, 52.10 passes/s]
Running MIL FP16ComputePrecision pass: 100%|███████████████████| 1/1 [00:00<00:00, 2.58 passes/s]
Running MIL Clean up passes: 100%|███████████████████████████| 11/11 [00:01<00:00, 10.25 passes/s]
Backend MacOSX is interactive backend. Turning interactive mode on.
return self.__proxy__.predict(data)
RuntimeError: {
NSLocalizedDescription = "Error computing NN outputs.";
}
Any idea how to debug this further? Do I need to bisect the model somehow, or can I get more meaningful logs somewhere?
Bisecting the model is a good idea. However there are a couple of easier things I recommend you try first.
1 - When calling ct.convert, try using convert_to="mlprogram".
2 - Also when calling ct.convert, try specifying the output shape using the outputs parameter. For example, if you have a single output with shape (1, 45), pass:
outputs=[ct.TensorType(shape=(1, 45))],
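If neither suggestion helps, the bisection mentioned above can be organized as a binary search over how many layers of the model are converted. The sketch below assumes a hypothetical predicate `fails(k)` that converts and runs only the first k layers and returns True when prediction errors out, and it assumes that once a prefix fails, every longer prefix fails too. Neither the predicate nor `first_failing_prefix` is a coremltools API.

```python
from typing import Callable, List


def first_failing_prefix(num_layers: int, fails: Callable[[int], bool]) -> int:
    """Return the smallest k in [1, num_layers] such that fails(k) is True.

    Assumes fails is monotone (False ... False True ... True) and that the
    full model fails, i.e. fails(num_layers) is True.
    """
    lo, hi = 1, num_layers
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(mid):
            hi = mid  # the culprit is at or before layer mid
        else:
            lo = mid + 1  # the first mid layers run fine
    return lo


# Toy stand-in: pretend layer 7 (1-based) is the one that breaks inference.
calls: List[int] = []


def toy_fails(k: int) -> bool:
    calls.append(k)
    return k >= 7


print(first_failing_prefix(20, toy_fails))  # 7
```

Each probe costs one conversion plus one prediction, so a 506-op graph would need on the order of nine probes instead of one per op.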
❓Question
Hi! I am in the middle of converting a graph neural network that at one point uses a message-passing layer, and while converting the torch model I get an unsupported-operation message for the torch.scatter_max operation. I saw that you have a scatter operation with mode max in the scatter_gather module here: https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.scatter_gather. Is this the same operation as https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html? If so, how can I reroute it through a custom operation? Thanks a lot!