apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Add torch cumsum dtype support #2373

Status: Closed (M-Quadra closed 4 weeks ago)

M-Quadra commented 1 month ago
import torch
from torch import nn
import coremltools as ct
from coremltools.converters.mil.mil import types

class Model(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.cumsum(dim=-1, dtype=torch.int32)

model = Model().eval()
x = torch.tensor([0.5, 1.5, 2.5])
y = model(x)
print(y)

traced_model = torch.jit.trace(model, (x,))
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="x", shape=x.shape, dtype=types.fp32),
    ],
    outputs=[
        ct.TensorType(name="y"),
    ],
)
mlmodel.save("tmp.mlpackage")

With this change, the converted Core ML model gives the right result for the example above, matching the PyTorch output.
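For context on what "the right result" means here: the PyTorch documentation for `cumsum` states that when `dtype` is given, the input is cast to that dtype before the operation is performed. The following is a minimal pure-Python sketch of that semantics, not the coremltools implementation; `cumsum_with_int_cast` is a hypothetical helper name, and `int()` stands in for the float-to-int32 cast (truncation toward zero).

```python
from itertools import accumulate


def cumsum_with_int_cast(values):
    """Hypothetical helper: cast each value to int first, then take the running sum.

    This mimics the order of operations of x.cumsum(dim=-1, dtype=torch.int32)
    for a 1-D input: cast first, accumulate second.
    """
    # int() truncates toward zero, like a float -> int32 cast.
    as_ints = [int(v) for v in values]
    # Running sum over the already-cast integers.
    return list(accumulate(as_ints))


result = cumsum_with_int_cast([0.5, 1.5, 2.5])
print(result)  # [0, 1, 3]
```

The order matters: summing in floating point first and casting afterwards would give 0.5, 2.0, 4.5 and then 0, 2, 4, which differs from the cast-first result.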

TobyRoseman commented 4 weeks ago

Change looks good.

CI: https://gitlab.com/coremltools1/coremltools/-/pipelines/1509752385

TobyRoseman commented 4 weeks ago

Thanks for the fix @M-Quadra