hidet-org / hidet

An open-source efficient deep learning framework/compiler, written in python.
https://hidet.org
Apache License 2.0

[Bug] ops.concat does not work the same as torch.cat #440

Open CBalaa opened 4 months ago

CBalaa commented 4 months ago

Describe the bug: ops.concat does not behave the same as torch.cat when one of the inputs is an empty tensor (shape=[0]), and this causes a runtime error when the interpreter translates torch.cat.

To Reproduce: the following code reproduces the problem:

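import torch
import hidet
from hidet.graph import ops

# torch.cat skips 1-D empty tensors of shape [0], so the result equals x (shape [3, 4, 5])
x = torch.rand([3, 4, 5])
y = torch.rand([0])
print(torch.cat([x, y]))

# Equivalent hidet code; uncommenting these lines raises an error
# x = hidet.randn([3, 4, 5])
# y = hidet.randn([0])
# print(ops.concat([x, y], axis=0))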

The hidet code raises an error, while the torch code runs normally.
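For reference, the torch.cat documentation states that all inputs must have the same shape except in the concatenating dimension, or be a 1-D empty tensor of size (0,); such empty inputs are skipped. The pure-Python sketch below computes the output shape under that rule, which is why torch.cat([x, y]) above returns a [3, 4, 5] result. The helper name torch_cat_shape is hypothetical and not part of torch or hidet; it works on shape tuples only, and it assumes (based on this report) that ops.concat does not apply the same exception.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def torch_cat_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Output shape of torch.cat-style concatenation (hypothetical helper).

    Follows the documented torch.cat rule: inputs must match in every
    dimension except `dim`, or be a 1-D empty tensor of shape (0,),
    which is skipped.
    """
    if not shapes:
        raise ValueError("expected at least one input shape")
    # Drop legacy 1-D empty inputs; this is the case that fails in hidet here.
    kept = [tuple(s) for s in shapes if tuple(s) != (0,)]
    if not kept:
        return (0,)  # every input was a 1-D empty tensor
    ref = kept[0]
    rank = len(ref)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    dim %= rank
    out = list(ref)
    for s in kept[1:]:
        if len(s) != rank:
            raise ValueError(f"rank mismatch: {s} vs {ref}")
        for axis in range(rank):
            if axis != dim and s[axis] != ref[axis]:
                raise ValueError(f"size mismatch on axis {axis}: {s} vs {ref}")
        out[dim] += s[dim]
    return tuple(out)


# The shapes from the reproduction above: x is [3, 4, 5], y is [0].
print(torch_cat_shape([(3, 4, 5), (0,)]))  # (3, 4, 5)
```

Under the same assumption, a user-side workaround is to drop the inputs whose shape is exactly [0] before passing the list to ops.concat.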

Expected behavior: I hope that ops.concat can behave as torch.cat does. Alternatively, how can I customize what the interpreter does, for example, map torch.cat to my own custom operator instead of the default ops.concat?
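This issue does not show how hidet's interpreter registers its torch-function converters, so the snippet below is only a hypothetical sketch of the override pattern the question asks about. The registry (CONVERTERS), the register decorator, and interpret are made-up names, not hidet's API, and shape tuples stand in for real tensors. The idea is that the interpreter looks converters up in a table at call time, so re-registering "torch.cat" with an explicit override flag replaces the default mapping.

```python
from typing import Callable, Dict, Sequence, Tuple

Shape = Tuple[int, ...]

# Hypothetical converter registry keyed by source op name.
# Illustrative only: this is NOT hidet's interpreter API.
CONVERTERS: Dict[str, Callable[..., Shape]] = {}


def register(op_name: str, override: bool = False):
    """Decorator that maps op_name to a converter function."""
    def wrap(fn):
        if op_name in CONVERTERS and not override:
            raise KeyError(f"{op_name} already has a converter")
        CONVERTERS[op_name] = fn
        return fn
    return wrap


@register("torch.cat")
def strict_concat(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    # Default converter: every input must have the same rank.
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError(f"rank mismatch: {list(shapes)}")
    out = list(shapes[0])
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


@register("torch.cat", override=True)
def torch_like_concat(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    # User override: skip 1-D empty inputs, as torch.cat does,
    # then reuse the default converter on the remaining inputs.
    kept = [s for s in shapes if tuple(s) != (0,)]
    return strict_concat(kept, dim) if kept else (0,)


def interpret(op_name: str, *args, **kwargs) -> Shape:
    # Lookup happens at call time, so the most recent registration wins.
    return CONVERTERS[op_name](*args, **kwargs)


print(interpret("torch.cat", [(3, 4, 5), (0,)]))  # (3, 4, 5)
```

Requiring override=True makes replacing a built-in mapping an explicit choice rather than an accidental name clash.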

Environment