nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

torch.aten.conv_tbc, conv1d, and conv3d #345

Closed renxida closed 1 month ago

renxida commented 10 months ago

Seems pretty straightforward. Just need to transpose the inputs.

According to this issue,

The input shape for nn.Conv1d is batch x channels x time (BCT), which would require a transpose since the rest of the network operates with time x batch x channel (TBC) tensors. conv_tbc takes time x batch x channel (TBC) input directly.

So this would be as simple as: TBC → (transpose 0, 1) → BTC → (transpose 1, 2) → BCT
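That transpose chain can be checked on shapes alone. A minimal NumPy sketch (shapes are illustrative; `np.swapaxes` mirrors `torch.Tensor.transpose(dim0, dim1)`):

```python
import numpy as np

# Illustrative shapes: T=7 (time), B=2 (batch), C=3 (channels).
x_tbc = np.zeros((7, 2, 3))

# transpose(0, 1): TBC -> BTC
x_btc = np.swapaxes(x_tbc, 0, 1)
assert x_btc.shape == (2, 7, 3)

# transpose(1, 2): BTC -> BCT, the layout nn.Conv1d expects
x_bct = np.swapaxes(x_btc, 1, 2)
assert x_bct.shape == (2, 3, 7)
```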

renxida commented 10 months ago

From https://github.com/pytorch/pytorch/blob/5bc896e5dc856ca831a7f78fdda3b95b1cb8c631/torch/onnx/symbolic_opset9.py#L3605


@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("conv_tbc", input, weight, bias, pad_i=pad)
    else:
        # input must have 3 dimensions, see:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
        # input = (time, batch, in_channels)
        # weight = (kernel_width, in_channels, out_channels)
        # bias = (out_channels,)
        input = g.op("Transpose", input, perm_i=[1, 2, 0])
        weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
        conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
        return g.op("Transpose", conv, perm_i=[2, 0, 1])     
renxida commented 10 months ago

It looks like to implement conv_tbc I need to implement conv1d first, and conv3d seems pretty trivial to tack on.
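Following the ONNX lowering quoted above, the decomposition can be sketched end to end in NumPy: a naive reference conv1d in BCT layout, plus the input/weight/output transposes. This is an illustrative sketch, not torch-mlir code; the function names and the naive loop are my own, and the layouts follow the comments in the ONNX symbolic (input `(T, B, C_in)`, weight `(K, C_in, C_out)`, bias `(C_out,)`):

```python
import numpy as np

def conv1d_bct(x, w, bias, pad):
    # Reference conv1d. x: (B, C_in, T); w: (C_out, C_in, K) -- nn.Conv1d layout.
    B, c_in, T = x.shape
    c_out, _, K = w.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
    t_out = T + 2 * pad - K + 1
    out = np.zeros((B, c_out, t_out))
    for t in range(t_out):
        # Sliding window (B, C_in, K), contracted against w over C_in and K.
        out[:, :, t] = np.einsum("bck,ock->bo", xp[:, :, t : t + K], w)
    return out + bias[None, :, None]

def conv_tbc_ref(x_tbc, w_tbc, bias, pad):
    # x_tbc: (T, B, C_in); w_tbc: (K, C_in, C_out) -- aten::conv_tbc layout.
    x_bct = x_tbc.transpose(1, 2, 0)   # perm [1, 2, 0]: TBC -> BCT
    w = w_tbc.transpose(2, 1, 0)       # perm [2, 1, 0]: (K,Cin,Cout) -> (Cout,Cin,K)
    y_bct = conv1d_bct(x_bct, w, bias, pad)
    return y_bct.transpose(2, 0, 1)    # perm [2, 0, 1]: BCT -> TBC
```

With all-ones input and a width-1 all-ones kernel (T=4, B=2, C_in=3, C_out=5, pad=0), every output element should be C_in = 3.0 and the result should come back in TBC layout, shape (4, 2, 5). Against PyTorch, the same call should match `torch.conv_tbc` up to float error.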

vivekkhandelwal1 commented 1 month ago

Support added for these ops.