apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

RuntimeError: PyTorch convert function for op 'pythonop' not implemented. #1135

Open lschaupp opened 3 years ago

lschaupp commented 3 years ago

Hey there,

I am getting the following error when converting a model using coremltools: RuntimeError: PyTorch convert function for op 'pythonop' not implemented. ("PyTorch convert function for op '{}' not implemented.".format(node.kind))

I am using the strict=False flag in the trace call (since the input size can vary): traced_model_detector = torch.jit.trace(...., strict=False)

Did anybody encounter this issue before?

Thank you!

TobyRoseman commented 3 years ago

I have not seen this issue before. It sounds like you have a layer of type pythonop in your model and we haven't implemented a way to convert that type of layer.

Do you have any details about this pythonop layer? Is this a custom layer in your model?

henbucuoshanghai commented 3 years ago

Help, how can I solve this? Thanks.

/torch/nn/modules/module.py", line 772, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'Crowd_locator' object has no attribute 'loss'

i-amgeek commented 3 years ago

Getting same issue.

RuntimeError: PyTorch convert function for op 'pythonop' not implemented.

tommy19970714 commented 3 years ago

+1

njb commented 2 years ago

+1

rotem154154 commented 2 years ago

+1

TobyRoseman commented 2 years ago

Can anyone share simple steps to reproduce this issue?

barca314 commented 2 years ago

+1

ahmedshoaib commented 2 years ago

Hey, got the same issue while trying to convert a simple UNet with an efficientnet encoder.

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="efficientnet-b4",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2,                      # model output channels (number of classes in your dataset)
)

import torch
model.eval()

example_input = torch.rand(1, 3, 160, 160)
traced_model = torch.jit.trace(model, example_input)
out = traced_model(example_input)

import coremltools as ct
model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)]
 )

TobyRoseman commented 2 years ago

Thanks @ahmedshoaib. I can reproduce this problem using your code.

If anyone can construct a minimal example (i.e., a simple standalone neural network) that reproduces this problem, that would be helpful.

Also can anyone share information about the pythonop op? I can't find any documentation for it.

iutlu commented 2 years ago

@TobyRoseman I've run into this with models that contain autograd.Functions. pythonop probably refers to operations that the PyTorch tracer cannot recognize (operations that call into Python).

Here's an example:

import coremltools as ct
import torch
import torch.autograd as autograd
import torch.nn as nn

class ExampleFunction(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, x):
        return x

class ExampleModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return ExampleFunction.apply(x)

def main():
    device = torch.device('cpu')
    model = ExampleModule().to(device=device).eval()
    model_input = torch.rand(1, 3, 256, 256).to(device=device)
    traced_model = torch.jit.trace(model, model_input)

    print()
    for node in traced_model.graph.nodes():
        print(f"{node=}")
        print(f"{node.kind()=}")
    print()

    # PyTorch convert function for op 'pythonop' not implemented.
    coreml_model = ct.convert(traced_model, inputs=[ct.TensorType(name='x', shape=model_input.shape)])

    return 0

if __name__ == '__main__':
    raise SystemExit(main())

In practice, one might run into this when trying to convert, e.g., EfficientNet from https://github.com/lukemelas/EfficientNet-PyTorch. There we hit the pythonop issue on the Swish function implementation:

# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)

again due to the wrapped autograd.Function.

The obvious solution is to use (again from the same repo)

# For compatibility with old PyTorch versions
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
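
In a model built on that repo, swapping the traceable Swish back in might look like this (a minimal sketch: set_swish is the helper EfficientNet-PyTorch itself provides, while replace_modules is a hypothetical generic fallback):

# If `model` is a lukemelas EfficientNet, this swaps MemoryEfficientSwish
# for the plain Swish module throughout the network:
model.set_swish(memory_efficient=False)

# Hypothetical generic fallback for other models: recursively replace
# every instance of one module class with a freshly constructed one.
import torch.nn as nn

def replace_modules(module: nn.Module, old_cls: type, make_new):
    for name, child in module.named_children():
        if isinstance(child, old_cls):
            setattr(module, name, make_new())
        else:
            replace_modules(child, old_cls, make_new)

# e.g. replace_modules(model, MemoryEfficientSwish, Swish)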

However, it might not always be possible to apply a bypass like this. For example, an autograd.Function might be needed to wrap a custom CUDA implementation of forward and backward.

This special case might raise the question of why we would need to care about this at all, since we wouldn't be able to trace and convert such a model anyway.

I think, however, that there's still value in handling the pythonop error in this case: one might want to provide a custom layer implementation in Swift / Metal for such a layer (register the torch op, bind it to the Swift class, convert to the neuralnetwork model type).

In this case, the naive approach of trying to register @register_torch_op(torch_alias=['pythonop']) breaks down in two ways: First, there might be multiple such CUDA functions to provide a custom implementation for. Second, this approach fails to capture potential scalar arguments that might be present:

class CustomFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input_1, input_2, param_1, param_2, param_3):
        ctx.save_for_backward(input_1, input_2)
        ctx.param_1 = param_1
        ctx.param_2 = param_2
        ctx.param_3 = param_3
        ...

    @staticmethod
    def backward(ctx, x):
        input_1, input_2 = ctx.saved_tensors
        param_1 = ctx.param_1
        param_2 = ctx.param_2
        param_3 = ctx.param_3
        ...

Normally, we'd expect to have 3 constant nodes (param_1, param_2, param_3) registered as inputs to our op, along with the nodes for the tensor inputs input_1 and input_2. But inspecting the generated op (of kind pythonop), we see that the only inputs are the input_1 and input_2 nodes.
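
For reference, these pieces can be inspected directly on the traced node (a small sketch using the same TorchScript PythonOp accessors the patch below relies on):

for node in traced_model.graph.nodes():
    if node.kind() == 'prim::PythonOp':
        print(node.pyname())       # autograd.Function class name, e.g. 'CustomFunction'
        print(node.cconv())        # calling convention: 'd' = tensor input, 'c' = captured scalar
        print(node.scalar_args())  # the scalar arguments captured at trace time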

To work around this, I have monkey-patched InternalTorchIRNode from coremltools; I'm sharing it in case anyone else might find it useful (no guarantees! :)):

from collections import defaultdict
from inspect import signature
from itertools import starmap

import coremltools.converters.mil.frontend.torch.internal_graph

class InternalTorchIRNode(coremltools.converters.mil.frontend.torch.internal_graph.InternalTorchIRNode):
    def __init__(self, node=None, **kwargs):
        super().__init__(node=node, **kwargs)
        if node.kind() == 'prim::PythonOp':
            self.prepend_constant_nodes(node)
            const_names = [const_node.outputs[0] for const_node in self.const_nodes]
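            # Interleave the constant names ('c') with the original tensor inputs ('d'), following the cconv calling-convention string.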
            self.inputs = list(map(next, starmap({"c": iter(const_names), "d": iter(self.inputs)}.get, node.cconv())))
            self.kind = f'autograd_{node.pyname().removesuffix("Function").lower()}'

    def prepend_constant_nodes(self, node):
        assert node.kind() == 'prim::PythonOp'
        parameters = list(signature(node.pyobj().__self__.forward).parameters)[1:]  # skip ctx
        constant_parameters = [p for p, arg_type in zip(parameters, node.cconv()) if arg_type == 'c']
        scalar_args = node.scalar_args()
        self.const_nodes = []
        for param, arg in zip(constant_parameters, scalar_args):
            name = f'{node.pyname()}.{param}'
            unique = self.get_unique_name_in_graph(name)
            name = f'{node.pyname()}.{unique}.{param}'
            const_node = self.get_constant_node(name=name, value=arg)
            self.const_nodes.append(const_node)
        for const_node in self.const_nodes:
            self.parent.nodes.append(const_node)

    def get_unique_name_in_graph(self, name):
        graph = self.parent
        counter = graph.const_node_counter = getattr(graph, 'const_node_counter', defaultdict(int))
        count = counter[name]
        counter[name] += 1
        return count

    def get_constant_node(self, *, name, value):
        const_node = type(self).__new__(type(self))
        const_node.parent = self.parent
        const_node.inputs = []
        const_node.outputs = [name]
        const_node.kind = 'constant'
        const_node.blocks = []
        const_node.attr = {'value': value}
        const_node.name = name
        return const_node

coremltools.converters.mil.frontend.torch.internal_graph.InternalTorchIRNode = InternalTorchIRNode

Here I'm using an arbitrary convention where invocations of CustomFunction.apply translate into ops of kind autograd_custom, which one can go ahead and register. The scalar parameters are also handled:

@register_torch_op
def autograd_custom(context, node):
    ...
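
A registered implementation might then look something like this (a hedged sketch: the identity body is just a placeholder, and a real implementation would rebuild the custom function out of MIL primitives):

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def autograd_custom(context, node):
    # Fetch the already-converted first input by name, emit a MIL op,
    # and register the result under this node's name.
    x = context[node.inputs[0]]
    res = mb.identity(x=x, name=node.name)
    context.add(res)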

Here the constant InternalTorchIRNodes for param_1, param_2 and param_3 are created and added to the graph. The naming convention for constant nodes here is {autograd_function_name}.{count}.{param_name}, which might not be consistent with how the other nodes are labeled, but it probably doesn't matter anyway (related: https://discuss.pytorch.org/t/accessing-parameter-names-of-jit-traced-autograd-functions/143750)

Note: you might need to replace the contents of the forward staticmethod temporarily for the conversion to go through (e.g. with a torch.zeros of the appropriate shape).
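
For example (a hypothetical stub; the zeros shape is an assumption and must match the real output shape):

class CustomFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input_1, input_2, param_1, param_2, param_3):
        # Temporary tracing stub: skip the real (e.g. CUDA) computation and
        # return a placeholder with the expected output shape.
        return torch.zeros_like(input_1)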

bigmindapp commented 2 years ago

I encountered this problem when converting to Core ML.

 File "/Users/mac/opt/anaconda3/envs/e2fgvi36/lib/python3.6/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    "PyTorch convert function for op '{}' not implemented.".format(node.kind)
RuntimeError: PyTorch convert function for op 'pythonop' not implemented.

I printed the node which may be causing the problem:

%feat_prop.2 = pythonop[inplace=0](%x.20, %offset.2, %mask.2, %feat_prop_module.deform_align.backward_.weight, %feat_prop_module.deform_align.backward_.bias)

And then I found the function in feat_prop:

import torch
import torch.nn as nn

from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
from mmcv.cnn import constant_init

from model.modules.flow_comp import flow_warp


class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
    """Second-order deformable alignment module."""
    def __init__(self, *args, **kwargs):
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
        )

        self.init_offset()

    def init_offset(self):
        constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        # offset
        offset = self.max_residue_magnitude * torch.tanh(
            torch.cat((o1, o2), dim=1))
        offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
        offset_1 = offset_1 + flow_1.flip(1).repeat(1,
                                                    offset_1.size(1) // 2, 1,
                                                    1)
        offset_2 = offset_2 + flow_2.flip(1).repeat(1,
                                                    offset_2.size(1) // 2, 1,
                                                    1)
        offset = torch.cat([offset_1, offset_2], dim=1)

        # mask
        mask = torch.sigmoid(mask)

        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)

I searched for modulated_deform_conv2d; it says 'modulated_deform_conv2d only supports GPU', so I mapped it to the CPU device.

Does this mean that if a model uses a function like modulated_deform_conv2d, which is not native to PyTorch, coremltools can't convert it?

bigmindapp commented 2 years ago

I ran the converter in Google Colab with PyTorch 1.11.0 + coremltools 6.0b1, and the error is the same:

%feat_prop.2 = pythonop[inplace=0](%x.20, %offset.2, %mask.2, %feat_prop_module.deform_align.backward_.weight, %feat_prop_module.deform_align.backward_.bias)

 File "/Users/mac/opt/anaconda3/envs/e2fgvi36/lib/python3.6/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    "PyTorch convert function for op '{}' not implemented.".format(node.kind)
RuntimeError: PyTorch convert function for op 'pythonop' not implemented.

srelbo commented 2 years ago

@TobyRoseman Is there a solution to this problem? or a workaround we could use?

QuchenFu commented 1 year ago

same issue

junpeiz commented 1 year ago

This issue also occurs when converting a HuggingFace AutoModelForSequenceClassification model.

fdchiu commented 1 year ago

Any progress on this? I am having the same issue with model conversion.

nitishsaDire commented 1 year ago

Has anyone faced "RuntimeError: PyTorch convert function for op 'clip' not implemented."? Please help. Thanks.

TerryyyZhang commented 11 months ago

I saw an article proposing that the model itself may be the source of the problem:

https://rockyshikoku.medium.com/solution-for-pytorch-convert-function-for-op-record-function-enter-not-implemented-37389304c75b

zihaog0724 commented 10 months ago

Same here. Any updates?

YifanShenSZ commented 10 months ago

Hey folks, a potential source of pythonop is PyTorch autograd. Concretely, if you customize something with torch.autograd.Function, then PyTorch autograd may give you some pythonop that cannot be converted.

If that is the case, you may try replacing that torch.autograd.Function with something more conventional, e.g. torch.nn.Module
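
As a concrete illustration, a minimal sketch reusing the Swish example from earlier in this thread:

import torch
import torch.nn as nn

# Opaque to the tracer: .apply() is recorded as a single prim::PythonOp node.
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        return i * torch.sigmoid(i)

# Traceable: the tracer records the sigmoid and mul ops directly.
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)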

xyu2 commented 8 months ago

Hey folks, a potential source of pythonop is PyTorch autograd. Concretely, if you customize something with torch.autograd.Function, then PyTorch autograd may give you some pythonop that cannot be converted.

If that is the case, you may try replacing that torch.autograd.Function with something more conventional, e.g. torch.nn.Module

Any update for the issue with torch.autograd.Function? Our model has the same conversion issue due to torch.autograd.Function, but our model couldn't be replaced with torch.nn.Module. Thanks!

SmallChungus1 commented 1 month ago

Hello, I encountered the same issue recently when trying to convert a trained SMP model to Core ML. I was wondering if anyone has found a solution for this? Thanks!

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="efficientnet-b4",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2,                      # model output channels (number of classes in your dataset)
)

import torch
model.eval()

example_input = torch.rand(1, 3, 160, 160)
traced_model = torch.jit.trace(model, example_input)
out = traced_model(example_input)

import coremltools as ct
model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)]
 )