lix19937 / tensorrt-insight

deep insight tensorrt
1 stars 0 forks source link

Defining a layer with g.op that has a fixed output shape #7

Open lix19937 opened 1 month ago

lix19937 commented 1 month ago

When a layer is defined with g.op, torch.onnx.export allows None to be passed as an input argument.

import torch

class TRT_SCA(torch.autograd.Function):
    """Placeholder autograd op exported to ONNX as the custom node ``TRT::SCATT``.

    The eager ``forward`` does not compute anything from its inputs. It only
    produces a tensor of the fixed shape ``OUTPUT_SHAPE``, so that tracing
    during ``torch.onnx.export`` can proceed.

    ``symbolic`` emits the custom node and annotates its output with the same
    fixed shape. ONNX shape inference has no rule for a custom op type, and
    without the annotation the exporter warns that shape inference for the
    type is missing.
    """

    # Fixed shape of the placeholder output and of the annotated ONNX output.
    OUTPUT_SHAPE = (1, 1600, 256)

    @staticmethod
    def forward(ctx,
                query,
                key,
                value,
                reference_points,
                spatial_shapes,
                reference_points_cam,
                bev_mask,
                level_start_index
                ):
        """Return a float32 tensor of shape ``OUTPUT_SHAPE``.

        Only ``query`` is inspected, and only for its device. All other
        arguments, which may be ``None``, are ignored.
        """
        # Use zeros rather than torch.randn. Zeros are deterministic, and they
        # do not advance the global RNG as a side effect of tracing/export.
        # Allocate on query's device so the placeholder can be combined with
        # the other traced tensors.
        device = query.device if isinstance(query, torch.Tensor) else None
        return torch.zeros(TRT_SCA.OUTPUT_SHAPE, dtype=torch.float32, device=device)

    @staticmethod
    def symbolic(g,
                 query,
                 key,
                 value,
                 reference_points,
                 spatial_shapes,
                 reference_points_cam,
                 bev_mask,
                 level_start_index
                 ):
        """Emit ``TRT::SCATT`` with all eight inputs, in order, and a shaped output."""
        out = g.op("TRT::SCATT",
                   query,
                   key,
                   value,
                   reference_points,
                   spatial_shapes,
                   reference_points_cam,
                   bev_mask,
                   level_start_index
                   )
        # Declare the output shape explicitly, since ONNX cannot infer it for
        # a custom op. The element type is inherited from query's traced type.
        # NOTE(review): this assumes query is float32, matching forward(); confirm.
        query_type = query.type()
        if isinstance(query_type, torch._C.TensorType):
            out.setType(query_type.with_sizes(list(TRT_SCA.OUTPUT_SHAPE)))
        return out


# Plain callable alias for use inside nn.Modules.
trt_sca = TRT_SCA.apply

class SpatialCrossAttention(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``trt_sca``.

    ``forward`` hands every argument, in the same positional order, to
    ``trt_sca`` and returns its result unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None):
        """Forward all eight arguments positionally to ``trt_sca``.

        The optional arguments default to ``None`` and are passed through as-is.

        Returns:
            Whatever ``trt_sca`` returns.
        """
        packed = (
            query,
            key,
            value,
            reference_points,
            spatial_shapes,
            reference_points_cam,
            bev_mask,
            level_start_index,
        )
        return trt_sca(*packed)
lix19937 commented 1 month ago

Running the export prints the following warning:

WARNING: The shape inference of xxx::Clip type is missing, so it may result in wrong shape inference
for the exported graph. Please consider adding it in symbolic function.