apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Add torch one_hot op #2358

Closed. M-Quadra closed this 1 month ago.

M-Quadra commented 1 month ago

Example:

import torch
from torch import nn
from torch.nn import functional as F

# Requires a coremltools build that includes this change (torch one_hot support).
import coremltools as ct
from coremltools.converters.mil.mil import types


class Model(nn.Module):
    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        # F.one_hot expects int64 class indices; each index becomes a length-10 row.
        return F.one_hot(x, num_classes=10)


model = Model().eval()
x = torch.arange(10)  # int64 indices 0..9
traced_model = torch.jit.trace(model, (x,))

# Let the input length vary between 1 and 1,000 elements.
var_dim = ct.RangeDim(lower_bound=1, upper_bound=1_000)
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="x", shape=ct.Shape([var_dim]), dtype=types.int32),
    ],
    outputs=[
        ct.TensorType(name="y"),
    ],
)
mlmodel.save("tmp.mlpackage")

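
To check the converted model by hand, it helps to pin down what `y` should contain. The sketch below is a hypothetical pure-Python reference (`one_hot_reference` is not part of coremltools or PyTorch) for the 1-D case used in the example: every index becomes a row of length `num_classes` with a single 1, so an input of length N within the `RangeDim` bounds should yield an output of shape (N, 10). It assumes PyTorch's documented behavior of inferring `num_classes` as `max + 1` when it is left at -1.

```python
from typing import List, Sequence


def one_hot_reference(indices: Sequence[int], num_classes: int = -1) -> List[List[int]]:
    """Hypothetical pure-Python reference for a 1-D F.one_hot call.

    Each index i becomes a row of length num_classes with 1 at position i
    and 0 elsewhere. num_classes=-1 mirrors PyTorch's default of inferring
    max(indices) + 1 from the input.
    """
    if num_classes == -1:
        if len(indices) == 0:
            raise ValueError("cannot infer num_classes from an empty input")
        num_classes = max(indices) + 1

    rows = []
    for i in indices:
        # Indices outside [0, num_classes) are invalid, as in PyTorch.
        if not 0 <= i < num_classes:
            raise ValueError(f"index {i} is out of range for {num_classes} classes")
        row = [0] * num_classes
        row[i] = 1
        rows.append(row)
    return rows


# Same input as the traced example: x = torch.arange(10), num_classes=10.
y = one_hot_reference(range(10), num_classes=10)
# y is the 10x10 identity matrix, i.e. shape (len(x), 10).
```

Comparing a prediction from the saved `tmp.mlpackage` against this reference (for example with inputs of different lengths inside the 1 to 1,000 range) is one way to exercise the flexible input shape.
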
TobyRoseman commented 1 month ago

Change looks good to me.

CI run: https://gitlab.com/coremltools1/coremltools/-/pipelines/1485462215