facebookincubator / AITemplate

AITemplate is a Python framework that renders neural networks into high-performance CUDA/HIP C++ code, specialized for FP16 TensorCore (NVIDIA GPU) and MatrixCore (AMD GPU) inference.
Apache License 2.0

`Unsupported workload for this conv2d specialization` when using dynamic shape together with permute #981

Closed: jiangwei221 closed this issue 9 months ago

jiangwei221 commented 9 months ago

Hi AIT team:

I'm working on a video model based on Stable Diffusion, and my input tensor has shape [batch, frame, channel, height, width]. To pass this tensor to the conv2d layer, I first have to merge the batch and frame dimensions and permute the channel dimension to the last position. The following is a minimal repro of the `Unsupported workload for this conv2d specialization` error.

import torch

from aitemplate.compiler import compile_model
from aitemplate.compiler.ops import permute, reshape
from aitemplate.frontend import IntVar, Tensor, nn
from aitemplate.testing import detect_target

torch.set_grad_enabled(False)

class MyConvNetAit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2dBias(
            4, 320, kernel_size=3, padding=1, stride=1
        )

    def forward(self, sample):
        batch_size, num_frames, channel, height, width = sample.shape()
        sample = reshape()(sample, [batch_size * num_frames, channel, height, width])
        sample = permute()(sample, [0, 2, 3, 1])  # NCHW -> NHWC: AIT conv ops use channels-last
        sample = self.conv_in(sample)
        sample = permute()(sample, [0, 3, 1, 2])  # NHWC -> back to NCHW
        sample = reshape()(sample, [batch_size, num_frames, 320, height, width])
        return sample

class MyConvNetPt(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = torch.nn.Conv2d(
            4, 320, kernel_size=3, padding=1, stride=1
        )

    def forward(self, sample):
        batch_size, num_frames, channel, height, width = sample.shape
        sample = sample.view(batch_size * num_frames, channel, height, width)
        sample = self.conv_in(sample)
        sample = sample.view(batch_size, num_frames, 320, height, width)
        return sample

def map_pt_params(ait_model, pt_model):
    ait_model.name_parameter_tensor()
    pt_params = dict(pt_model.named_parameters())
    mapped_pt_params = {}
    for name, _ in ait_model.named_parameters():
        ait_name = name.replace(".", "_")
        assert name in pt_params
        mapped_pt_params[ait_name] = pt_params[name]
    return mapped_pt_params

def mark_output(y):
    if type(y) is not tuple:
        y = (y,)
    for i in range(len(y)):
        y[i]._attrs["is_output"] = True
        y[i]._attrs["name"] = "output_%d" % (i)
        y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
        print("output_{} shape: {}".format(i, y_shape))

def my_compile():
    ait_model = MyConvNetAit()
    ait_model.name_parameter_tensor()

    pt_model = MyConvNetPt().cuda().half()
    pt_model.eval()

    weights = map_pt_params(ait_model, pt_model)

    b, f, c, h, w = 2, 32, 4, 40, 72
    x_pt = torch.randn(b, f, c, h, w).cuda().half()
    y_pt = pt_model(x_pt.clone())
    print(y_pt.shape)
    # Dynamic dims: each IntVar carries the [min, max] range of allowed values.
    ait_h = IntVar(values=[40, 72], name="height")
    ait_w = IntVar(values=[40, 72], name="width")
    ait_f = IntVar(values=[16, 64], name="num_frames")
    x_ait = Tensor(
        shape=[b, ait_f, c, ait_h, ait_w],
        dtype="float16",
        name="hidden_states",
        is_input=True,
    )

    y_ait = ait_model(sample=x_ait)
    mark_output(y_ait)
    target = detect_target(use_fp16_acc=True)

    with compile_model(y_ait, target, "./tmp", "my_conv", constants=weights) as module:
        y_ait_infered = torch.zeros([b, f, 320, h, w]).cuda().half()
        inputs = {
            "hidden_states": x_pt
        }
        outputs = {"output_0": y_ait_infered}
        module.run_with_tensors(inputs, outputs)

if __name__ == "__main__":
    my_compile()

Do you have any idea why I can't execute this model? Could it be a hardware/environment issue or a limitation of the current version of AIT? Thanks a lot!

aakhundov commented 9 months ago

@jiangwei221 Thanks for reporting! I can repro the error. Will look into it and circle back.

aakhundov commented 9 months ago

@jiangwei221 So the error you're getting is due to the way reshape is implemented in AIT.

To resolve the error, replace this line

sample = reshape()(sample, [batch_size * num_frames, channel, height, width])

with this line

sample = reshape()(sample, [-1, channel, height, width])

With this change, compilation runs successfully. However, the outputs of your PT and AIT models don't match, so I assume something in the two implementations is not equivalent. Let me know if you need help figuring that out.
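
For reference, here is a minimal sketch of MyConvNetAit.forward with just that one line swapped; everything else stays as in the original repro:

    def forward(self, sample):
        batch_size, num_frames, channel, height, width = sample.shape()
        # -1 lets AIT infer the merged batch*frame dim instead of computing it
        # from the product of a static dim and a dynamic dim.
        sample = reshape()(sample, [-1, channel, height, width])
        sample = permute()(sample, [0, 2, 3, 1])
        sample = self.conv_in(sample)
        sample = permute()(sample, [0, 3, 1, 2])
        sample = reshape()(sample, [batch_size, num_frames, 320, height, width])
        return sample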

jiangwei221 commented 9 months ago

I see, thanks a lot! A follow-up question: what if I want to do something like

sample = reshape()(sample, [batch_size * num_frames, height * width, channel])

where the 1st & 2nd dims are merged and the 3rd & 4th dims are also merged? What should I do in that case?

aakhundov commented 9 months ago

Good question. I believe it should work in general. The issue in your case may be related to batch_size being an IntImm (a fixed / static dim) while num_frames is an IntVar (a dynamic dim), although even that combination should work. I'll look into it and let you know what's happening there.
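
For context, in the repro above the batch dim is a plain Python int while the frame dim is declared dynamic, so the first reshape mixes the two kinds of dims:

    b = 2                                               # a plain int becomes an IntImm (static dim)
    ait_f = IntVar(values=[16, 64], name="num_frames")  # an IntVar (dynamic dim)
    # batch_size * num_frames in forward() therefore multiplies a static dim
    # by a dynamic one, which may be what trips the original reshape.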

aakhundov commented 9 months ago

OK, so here's what is happening. The aitemplate.Tensor.shape() method is actually meant to be used in the IR, not in hand-written AIT model code. What I'd suggest using instead is the size operator. If you from aitemplate.compiler.ops import reshape, permute, size and do this in MyConvNetAit.forward:

batch_size, num_frames, channel, height, width = size()(sample)

your original code will work just fine. You can then do any basic arithmetic (+, -, *, //) on the dims, static or dynamic. Hope this helps.
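
Putting that together, a minimal sketch of MyConvNetAit.forward built around size(); the trailing comment covers the merged-dims reshape from the follow-up question:

    from aitemplate.compiler.ops import permute, reshape, size

    def forward(self, sample):
        # size() yields IR-level dims that support +, -, *, // arithmetic,
        # whether a dim is static (IntImm) or dynamic (IntVar).
        batch_size, num_frames, channel, height, width = size()(sample)
        sample = reshape()(sample, [batch_size * num_frames, channel, height, width])
        sample = permute()(sample, [0, 2, 3, 1])
        sample = self.conv_in(sample)
        sample = permute()(sample, [0, 3, 1, 2])
        sample = reshape()(sample, [batch_size, num_frames, 320, height, width])
        return sample

    # The follow-up case merges two pairs of dims the same way:
    # sample = reshape()(sample, [batch_size * num_frames, height * width, channel])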

jiangwei221 commented 9 months ago

Hey @aakhundov, after I changed all x.shape() calls to size()(x), everything works now! The only exception is that one model returns a gcc internal error, as mentioned in #980, but that is easily solved by upgrading gcc 9 to gcc 10. Thanks a lot for your help!