apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Torch var not found in context - although variable is class attribute #1355

Open: JRGit4UE opened this issue 2 years ago

JRGit4UE commented 2 years ago

❓ Converting a simple PyTorch model via TorchScript to Core ML fails because a member variable cannot be found in the forward() call

System Information: PyTorch 1.9.1, coremltools 5.1.0

I am new to converting models. Could someone please tell me why the self.attribute1 value cannot be found in the forward() call, and how I can work around that error?

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Dict, List
import coremltools as ct

print(torch.__version__) # 1.9.1
print(ct.__version__) # 5.1.0

class SimpleTest(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attribute1: float = 42.

    def forward(self, x: Tensor) -> Tensor:
        fillval: float = 42.              # this alone works
        fillval: float = self.attribute1  # ValueError: Torch var fillval.3 not found in context
        some_result = torch.full((3, 550, 550), fillval)
        return some_result

test_batch = torch.rand(1,3,550,550)
m = SimpleTest()
m.eval()
result = m(test_batch)
print('1 Before torch-script -------------------')
s = torch.jit.script(m)
print('2 Before coreml-convert -------------------')
c = ct.convert(s,
    inputs=[ct.TensorType(name="test", shape=test_batch.shape)],
    source='auto',
    minimum_deployment_target=ct.target.iOS15,
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.precision.FLOAT32,
    convert_to='mlprogram',
    debug=True
)
JRGit4UE commented 2 years ago

Any comments? Have I got something completely wrong, or is it possibly a bug? Why does the context of the forward(x) node not contain the self.attribute1 value? Is there a way to squeeze the missing var into the context?
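One way to dig into that question is to print the TorchScript graph that ct.convert receives. This is a minimal sketch, assuming a recent PyTorch and a smaller fill shape (3, 4, 4) instead of the issue's (3, 550, 550). The scripted graph reads the plain float attribute through a prim::GetAttr node at run time rather than embedding 42. as a constant. A plausible reading, not confirmed anywhere in this thread, is that the coremltools 5.1.0 PyTorch frontend did not resolve that GetAttr on a non-tensor attribute, which would surface as the "not found in context" error.

```python
import torch
import torch.nn as nn
from torch import Tensor


class SimpleTest(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attribute1: float = 42.

    def forward(self, x: Tensor) -> Tensor:
        fillval: float = self.attribute1
        # Reduced shape (assumption) to keep the demo light.
        return torch.full((3, 4, 4), fillval)


s = torch.jit.script(SimpleTest().eval())
graph_text = str(s.graph)
print(graph_text)

# The float attribute is fetched with prim::GetAttr on `self` at run time,
# not baked into the graph as a prim::Constant.
print("prim::GetAttr" in graph_text and "attribute1" in graph_text)
```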

TobyRoseman commented 2 years ago

Is there a reason you need to use a PyTorch script rather than a PyTorch traced model?

Replacing this line: s = torch.jit.script(m) with the following line: s = torch.jit.trace(m, test_batch) will allow you to convert the model.
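A minimal sketch of why tracing sidesteps the error, and the caveat that comes with it (assuming a recent PyTorch and a reduced output shape of (3, 4, 4)): torch.jit.trace records the value self.attribute1 had at trace time as a constant in the graph, so later changes to the attribute are not reflected in the traced module. The traced module t is what would then be passed to ct.convert in place of s; the conversion call is not repeated here.

```python
import torch
import torch.nn as nn
from torch import Tensor


class SimpleTest(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attribute1: float = 42.

    def forward(self, x: Tensor) -> Tensor:
        # Reduced shape (assumption) instead of (3, 550, 550).
        return torch.full((3, 4, 4), self.attribute1)


m = SimpleTest().eval()
example = torch.rand(1, 3, 4, 4)
t = torch.jit.trace(m, example)

# The tracer froze 42.0 into the graph, so changing the Python attribute
# afterwards only affects the eager module, not the traced one.
m.attribute1 = 7.0
print(m(example)[0, 0, 0].item())  # 7.0
print(t(example)[0, 0, 0].item())  # 42.0
```

For a fixed hyperparameter like a fill value this freezing is harmless, but any attribute that is meant to change after export would silently stay at its trace-time value.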

JRGit4UE commented 2 years ago

Thank you for the hint; I will check how far I can get with it.

> Is there a reason you need to use a PyTorch script rather than a PyTorch traced model?

Definitely. The PyTorch model I am trying to convert is Yolact++, which is full of untraceable code sections, so every convenience coremltools can offer would save a lot of dev time.
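If the model has to stay scripted, one option that may be worth trying is to declare the attribute as typing.Final. This was not tested in this thread and has not been verified against coremltools 5.1.0. TorchScript treats Final module attributes as constants and inlines their values into the graph, so no prim::GetAttr node is left for the converter to resolve. A minimal sketch with the same reduced shape; the class name SimpleTestFinal is hypothetical:

```python
from typing import Final

import torch
import torch.nn as nn
from torch import Tensor


class SimpleTestFinal(nn.Module):
    # Final tells TorchScript this attribute never changes after __init__,
    # so torch.jit.script substitutes its value as a constant.
    attribute1: Final[float]

    def __init__(self) -> None:
        super().__init__()
        self.attribute1 = 42.

    def forward(self, x: Tensor) -> Tensor:
        fillval: float = self.attribute1
        return torch.full((3, 4, 4), fillval)


s = torch.jit.script(SimpleTestFinal().eval())
graph_text = str(s.graph)
print("prim::GetAttr" in graph_text)  # False: the value is now a prim::Constant
print(s(torch.rand(1, 3, 4, 4))[0, 0, 0].item())  # 42.0
```

The trade-off is the same as with tracing: the value is fixed when the module is scripted.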

glenntu commented 2 years ago

Hi @JRGit4UE, below is a working example of your code. The only change is that self.attribute1 is now a Parameter.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Dict, List
import coremltools as ct

print(torch.__version__) # 1.9.1
print(ct.__version__) # 5.1.0

class SimpleTest(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.attribute1 = nn.Parameter(torch.tensor(42.), requires_grad=False)  # 👈 the only change: a non-trainable Parameter instead of a plain float

    def forward(self, x: Tensor) -> Tensor:
        fillval = self.attribute1  # now converts (the plain-float version raised ValueError: Torch var fillval.3 not found in context)
        some_result = torch.full((3, 550, 550), fillval)
        return some_result

test_batch = torch.rand(1,3,550,550)
m = SimpleTest()
m.eval()
result = m(test_batch)
print('1 Before torch-script -------------------')
s = torch.jit.script(m)
print('2 Before coreml-convert -------------------')
c = ct.convert(s,
    inputs=[ct.TensorType(name="test", shape=test_batch.shape)],
    source='auto',
    minimum_deployment_target=ct.target.iOS15,
    compute_units=ct.ComputeUnit.CPU_ONLY,
    compute_precision=ct.precision.FLOAT32,
    convert_to='mlprogram',
    debug=True
)

According to the PyTorch docs:

Parameters are Tensor subclasses, that have a very special property when used with Modules - when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator.
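To see the property the quote describes, this minimal sketch compares a plain float attribute with a Parameter holding the same value (the class names WithFloat and WithParameter are hypothetical). Only the Parameter is registered with the module, so it appears in named_parameters() and state_dict(); requires_grad=False keeps it out of training while it remains a registered tensor attribute of the module.

```python
import torch
import torch.nn as nn


class WithFloat(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attribute1: float = 42.  # plain Python attribute, not registered


class WithParameter(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Assigning a Parameter registers it automatically.
        self.attribute1 = nn.Parameter(torch.tensor(42.), requires_grad=False)


print([n for n, _ in WithFloat().named_parameters()])      # []
print([n for n, _ in WithParameter().named_parameters()])  # ['attribute1']
print(WithParameter().state_dict()["attribute1"].item())   # 42.0
```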

JRGit4UE commented 2 years ago

@glenntu 👍 Simple change, great result. Thanks a lot!