taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0
25.35k stars 2.27k forks

Add support for conversion of torch scalar to taichi scalar #8373

Open obust opened 10 months ago

obust commented 10 months ago

Torch scalars are 0-dimensional tensors.

Currently, Taichi cannot convert a zero-dimensional tensor (e.g. torch.tensor(2, dtype=torch.int32)) to a Taichi scalar (e.g. ti.int32).

Example

Suppose you have this fill kernel:

import torch
import taichi as ti

ti.init(arch=ti.cpu)  # initialize the Taichi runtime before launching kernels

@ti.kernel
def fill(out: ti.types.ndarray(dtype=ti.int32), value: ti.int32):
    for I in ti.grouped(out):
        out[I] = value

We can pass a Python scalar:

out = torch.empty((10,), dtype=torch.int32)
value = 2  # Python scalar
fill(out, value)
# out is now: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=torch.int32)

We can also pass a NumPy scalar:

import numpy as np

out = np.empty((10,), dtype=np.int32)
value = np.int32(2)  # NumPy scalar
fill(out, value)
# out is now: array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)

But we cannot pass a torch scalar:

out = torch.empty((10,), dtype=torch.int32)
value = torch.tensor(2, dtype=torch.int32)  # torch scalar
fill(out, value)
TaichiRuntimeTypeError:
Argument 1 (type=<class 'torch.Tensor'>) cannot be converted into required type i32

jim19930609 commented 10 months ago

As you mentioned, a "torch scalar" is a zero-dimensional tensor, so you still need to treat it as an ndarray in Taichi and read its value with value[None]:

import torch
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def fill(out: ti.types.ndarray(), value: ti.types.ndarray()):
    for I in ti.grouped(out):
        out[I] = value[None]

out = torch.empty((10,), dtype=torch.int32)
value = torch.tensor(2, dtype=torch.int32)  # torch scalar
fill(out, value)

Try this out.
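
For readers who want to keep the original kernel signature (value: ti.int32), another possible workaround is to unwrap the zero-dimensional tensor on the Python side before the call. The sketch below is not part of Taichi. The helper name to_python_scalar is hypothetical, and it assumes only that the argument exposes ndim and item(), which torch.Tensor does.

```python
import numbers


def to_python_scalar(value):
    """Return a plain number for scalars and 0-dimensional array-likes.

    Hypothetical helper, not a Taichi or PyTorch API. It relies on duck
    typing: any object with ndim == 0 and a callable item() is unwrapped.
    """
    # Plain Python numbers pass through unchanged.
    if isinstance(value, numbers.Number):
        return value

    ndim = getattr(value, "ndim", None)
    item = getattr(value, "item", None)

    # A 0-dimensional tensor holds exactly one element, which item() returns
    # as a plain Python number.
    if ndim == 0 and callable(item):
        return item()

    raise TypeError(
        f"expected a scalar or a 0-dimensional array-like, got {type(value).__name__}"
    )
```

With this helper, the failing call from the issue could be written as fill(out, to_python_scalar(value)). Note that item() copies the value back to the host, so for a tensor that lives on the GPU this forces a synchronization; passing the tensor as an ndarray, as suggested above, avoids that.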