renxida closed this issue 1 month ago.
```python
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("conv_tbc", input, weight, bias, pad_i=pad)
    else:
        # input must have 3 dimensions, see:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
        # input = (time, batch, in_channels)
        # weight = (kernel_width, in_channels, out_channels)
        # bias = (out_channels,)
        input = g.op("Transpose", input, perm_i=[1, 2, 0])
        weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
        conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
        return g.op("Transpose", conv, perm_i=[2, 0, 1])
```
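As a sanity check on the three `Transpose` perms in the snippet, here is a minimal pure-Python sketch (no torch or onnx required) that computes conv_tbc directly in (time, batch, channels) layout and compares it with the Transpose → conv1d → Transpose route. It assumes stride 1, dilation 1, a single group and symmetric zero padding, matching the `[1], [pad], [1], 1` arguments passed to `conv1d`. The helper names (`permute3`, `conv1d_bct`, `conv_tbc_direct`, `conv_tbc_via_conv1d`) are hypothetical, and nested lists of small integers stand in for tensors so the comparison is exact.

```python
# Hypothetical pure-Python check of the conv_tbc lowering; nested lists act as 3-D tensors.


def permute3(x, perm):
    """Reorder the axes of a 3-D nested list, like ONNX Transpose(perm=perm)."""
    shape = [len(x), len(x[0]), len(x[0][0])]
    out_shape = [shape[p] for p in perm]
    out = [[[0] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]
    for i in range(out_shape[0]):
        for j in range(out_shape[1]):
            for k in range(out_shape[2]):
                src = [0, 0, 0]
                # Output axis d reads from input axis perm[d].
                src[perm[0]], src[perm[1]], src[perm[2]] = i, j, k
                out[i][j][k] = x[src[0]][src[1]][src[2]]
    return out


def conv1d_bct(x, w, b, pad):
    """Cross-correlation like ONNX Conv: x (B, Cin, T), w (Cout, Cin, K), b (Cout,)."""
    B, Cin, T = len(x), len(x[0]), len(x[0][0])
    Cout, K = len(w), len(w[0][0])
    T_out = T + 2 * pad - K + 1
    out = [[[0] * T_out for _ in range(Cout)] for _ in range(B)]
    for n in range(B):
        for o in range(Cout):
            for t in range(T_out):
                acc = b[o]
                for c in range(Cin):
                    for k in range(K):
                        ti = t - pad + k  # positions outside [0, T) are zero padding
                        if 0 <= ti < T:
                            acc += x[n][c][ti] * w[o][c][k]
                out[n][o][t] = acc
    return out


def conv_tbc_direct(x, w, b, pad):
    """Reference conv_tbc: x (T, B, Cin), w (K, Cin, Cout), b (Cout,)."""
    T, B, Cin = len(x), len(x[0]), len(x[0][0])
    K, Cout = len(w), len(w[0][0])
    T_out = T + 2 * pad - K + 1
    out = [[[0] * Cout for _ in range(B)] for _ in range(T_out)]
    for t in range(T_out):
        for n in range(B):
            for o in range(Cout):
                acc = b[o]
                for k in range(K):
                    ti = t - pad + k  # positions outside [0, T) are zero padding
                    if 0 <= ti < T:
                        for c in range(Cin):
                            acc += x[ti][n][c] * w[k][c][o]
                out[t][n][o] = acc
    return out


def conv_tbc_via_conv1d(x, w, b, pad):
    """Same three transposes as the ONNX symbolic above."""
    x_bct = permute3(x, [1, 2, 0])      # (T, B, Cin)   -> (B, Cin, T)
    w_oik = permute3(w, [2, 1, 0])      # (K, Cin, Cout) -> (Cout, Cin, K)
    y_bct = conv1d_bct(x_bct, w_oik, b, pad)
    return permute3(y_bct, [2, 0, 1])   # (B, Cout, T_out) -> (T_out, B, Cout)


T, B, CIN, K, COUT = 5, 2, 3, 3, 4
x = [[[(t * 7 + n * 3 + c) % 5 - 2 for c in range(CIN)] for n in range(B)] for t in range(T)]
w = [[[(k * 5 + c * 2 + o) % 3 - 1 for o in range(COUT)] for c in range(CIN)] for k in range(K)]
b = [o - 1 for o in range(COUT)]

y = conv_tbc_via_conv1d(x, w, b, pad=1)
print(len(y), len(y[0]), len(y[0][0]))  # 5 2 4, i.e. (T_out, B, Cout)
```

Integer inputs keep the two summation orders bit-for-bit identical; with floats you would compare within a tolerance instead.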
It looks like to implement conv_tbc I need to implement conv1d, and conv3d seems pretty trivial to tack on.
Support added for these ops.
Seems pretty straightforward. Just need to transpose the inputs.
According to this issue, this would be as simple as: TBC → transpose(0, 1) → BTC → transpose(1, 2) → BCT.
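The two pairwise transposes compose into the single `perm=[1, 2, 0]` that the symbolic passes to ONNX `Transpose` for the input, and the output perm `[2, 0, 1]` is just its inverse (BCT back to TBC). Note that the weight also needs its own transpose, `[2, 1, 0]`, from (kernel_width, in_channels, out_channels) to the (out_channels, in_channels, kernel_width) layout that Conv expects. A small sketch of that bookkeeping, with hypothetical helper names:

```python
def swap(perm, a, b):
    """Perm after a torch-style transpose(a, b) is applied on top of `perm`."""
    p = list(perm)
    p[a], p[b] = p[b], p[a]
    return p


def apply_perm(axes, perm):
    """Axis labels after ONNX Transpose(perm): output axis d is input axis perm[d]."""
    return [axes[p] for p in perm]


def invert(perm):
    """Perm that undoes `perm`."""
    inv = [0] * len(perm)
    for d, p in enumerate(perm):
        inv[p] = d
    return inv


identity = [0, 1, 2]
# TBC -> transpose(0, 1) -> BTC -> transpose(1, 2) -> BCT
tbc_to_bct = swap(swap(identity, 0, 1), 1, 2)

print(tbc_to_bct, apply_perm(list("TBC"), tbc_to_bct), invert(tbc_to_bct))
# [1, 2, 0] ['B', 'C', 'T'] [2, 0, 1]
```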