apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io

Flexible shapes not working with PyTorch unified converter #991

Closed · 3DTOPO closed this issue 2 years ago

3DTOPO commented 4 years ago

🐞Describe the bug

I made changes to my model so that I could use the recommended unified converter. Conversion succeeds without issue and reports that flexible shapes are supported (in both Python and Xcode).

Running a prediction with any shape in the supported range other than the fixed shape fails with an error; the fixed-shape input works as expected. I've tried both GPU and CPU-only execution.
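
For reference, a prediction call along these lines triggers the failure (a sketch; the file name and resize target are illustrative, not from the original report):

import coremltools as ct
from PIL import Image

mlmodel = ct.models.MLModel("TransformerNet.mlmodel")

# 1024x1024 (the fixed shape) works; any other size in the
# declared (256, 3072) range fails with the trace below.
img = Image.open("test.jpg").resize((512, 512))  # hypothetical input image
out = mlmodel.predict({"input_1": img})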

Trace

[espresso] [Espresso::handle_ex_plan] exception=Espresso exception: "Not implemented": axis -4 not implemented status=-9
[coreml] Failure dynamically resizing for sequence length.
[coreml] Failure in resetSizes.
prediction error: Error Domain=com.apple.CoreML Code=0 "Failure dynamically resizing for sequence length." UserInfo={NSLocalizedDescription=Failure dynamically resizing for sequence length.}

To Reproduce

The source code and model are in the attached archive.

import torch
import torch.nn as nn
import coremltools as ct
import coremltools.proto.FeatureTypes_pb2 as ft
from coremltools.models.neural_network import flexible_shape_utils
from model import TransformerNet

channels = 3
width = 1024
height = 1024

torch_model = TransformerNet()
#torch_model.load_state_dict(torch.load('TrainedModel.pth', map_location=torch.device('cpu')))
torch_model.eval()

example_input = torch.rand(1, channels, width, height)
traced_model = torch.jit.trace(torch_model, example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.ImageType(name="input_1", shape=example_input.shape)], 
    minimum_ios_deployment_target='13'
)

# Note: using "input" as the name creates a name collision

spec = mlmodel.get_spec()

# Needed because the documentation states that outputs
# must not be specified for PyTorch models
output = spec.description.output[0]
output.type.imageType.colorSpace = ft.ImageFeatureType.RGB
output.type.imageType.height = height
output.type.imageType.width = width

ct.utils.rename_feature(spec, '782', 'output')

img_size_ranges = flexible_shape_utils.NeuralNetworkImageSizeRange(height_range=(256, 3072), width_range=(256, 3072))
flexible_shape_utils.update_image_size_range(spec, feature_name='input_1', size_range=img_size_ranges)
flexible_shape_utils.update_image_size_range(spec, feature_name='output', size_range=img_size_ranges)

ct.utils.save_spec(spec, "TransformerNet.mlmodel")
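
For what it's worth, coremltools 4+ also lets the size range be declared at conversion time with ct.RangeDim, instead of patching the spec afterwards. A sketch under that assumption (untested against this model):

# Hedged alternative: declare the flexible range up front at conversion time.
flexible_shape = ct.Shape(shape=(1, channels,
                                 ct.RangeDim(lower_bound=256, upper_bound=3072),
                                 ct.RangeDim(lower_bound=256, upper_bound=3072)))
mlmodel_flex = ct.convert(
    traced_model,
    inputs=[ct.ImageType(name="input_1", shape=flexible_shape)],
)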

model.py:

import torch
import torch.nn as nn

class TransformerNet(torch.nn.Module):

    def __init__(self):
        super(TransformerNet, self).__init__()
        # Initial convolution layers
        self.conv1 = ConvLayer(3, 8, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
        self.conv2 = ConvLayer(8, 16, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(16, affine=True)
        self.conv3 = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(32, affine=True)
        # Residual layers
        self.res1 = ResidualBlock(32)
        self.res2 = ResidualBlock(32)
        self.res3 = ResidualBlock(32)
        self.res4 = ResidualBlock(32)
        self.res5 = ResidualBlock(32)
        # Upsampling Layers
        self.deconv1 = UpsampleConvLayer(32, 16, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(16, affine=True)
        self.deconv2 = UpsampleConvLayer(16, 8, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(8, affine=True)
        self.deconv3 = ConvLayer(8, 3, kernel_size=9, stride=1)
        # Non-linearities
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        y = self.relu(self.in1(self.conv1(X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y = self.relu(self.in4(self.deconv1(y)))
        y = self.relu(self.in5(self.deconv2(y)))
        y = self.deconv3(y)
        return y

class ConvLayer(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        #self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.reflection_pad = ReflectPad2d_rev(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

class ResidualBlock(torch.nn.Module):

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out

class UpsampleConvLayer(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        reflection_padding = kernel_size // 2
        #self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
        self.reflection_pad = ReflectPad2d_rev(reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out

class ReflectPad2d_rev(nn.Module):
    # Reflection padding built from slicing and concatenation, used in place of
    # torch.nn.ReflectionPad2d (commented out above) for converter compatibility.

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        a = self.size
        L_list, R_list = [], []
        U_list, D_list = [], []
        for i in range(a):  # gather reflected columns for the left and right pads
            l = x[:, :, :, (a-i):(a-i+1)]
            L_list.append(l)
            r = x[:, :, :, (i-a-1):(i-a)]
            R_list.append(r)
        L_list.append(x)
        x = torch.cat(L_list+R_list[::-1], dim=3)
        for i in range(a):  # gather reflected rows for the top and bottom pads
            u = x[:, :, (a-i):(a-i+1), :]
            U_list.append(u)
            d = x[:, :, (i-a-1):(i-a), :]
            D_list.append(d)
        U_list.append(x)
        x = torch.cat(U_list+D_list[::-1], dim=2)
        return x
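
For what it's worth, ReflectPad2d_rev should be numerically identical to torch.nn.ReflectionPad2d, and the network preserves spatial size for inputs whose sides are divisible by 4 (two stride-2 downsamples followed by two 2x upsamples). A quick sanity check, as a sketch (not part of the original report):

pad, x = 4, torch.rand(1, 3, 256, 256)
assert torch.allclose(torch.nn.ReflectionPad2d(pad)(x),
                      ReflectPad2d_rev(pad)(x))

# Output size matches input size when height and width are multiples of 4.
y = TransformerNet()(x)
assert y.shape == x.shape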

System environment (please complete the following information):

Additional context

This issue severely restricts deploying MLModels across my workflow.

Mstronach commented 3 years ago

Hi Jeshua, thank you for submitting this issue. Is this different from issue #992? Thanks!

3DTOPO commented 3 years ago

Hi Mstronach, you're quite welcome. Yes, it appears to be a duplicate. I hadn't noticed the double-post before. It must have gotten posted twice by accident.

TobyRoseman commented 2 years ago

> Yes, it appears to be a duplicate.

OK, closing as a duplicate.