open-mmlab / mmaction2

OpenMMLab's Next Generation Video Understanding Toolbox and Benchmark
https://mmaction2.readthedocs.io
Apache License 2.0

[Bug] can uniformerv2 do batch inference? #2748

Open aixiaodewugege opened 8 months ago

aixiaodewugege commented 8 months ago

Branch

main branch (1.x version, such as v1.0.0, or dev-1.x branch)

Prerequisite

Environment

The forward function of uniformerv2 is:

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    print(x.shape)
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    N, C, T, H, W = x.shape
    x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)

    x = torch.cat([
        self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
    ],
                  dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + self.positional_embedding.to(x.dtype)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    out = self.transformer(x)
    return out
```
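As a minimal sketch of the shape bookkeeping in question (using NumPy in place of PyTorch, with assumed sizes): the permute/reshape folds the batch dimension `N` and the temporal dimension `T` together, so every (clip, frame) pair becomes one independent sequence of `H * W` patch tokens.

```python
import numpy as np

# Assumed sizes for illustration: after conv1 the tensor is [N, C, T, H, W].
N, C, T, H, W = 4, 8, 16, 7, 7  # a batch of 4 clips, 16 frames each
x = np.random.rand(N, C, T, H, W)

# Same rearrangement as the forward function above:
# [N, C, T, H, W] -> [N, T, H, W, C] -> [N * T, H * W, C]
x = x.transpose(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
print(x.shape)  # (64, 49, 8): 4 clips * 16 frames, 49 spatial tokens each
```

The transformer then treats the leading `N * T` axis as an ordinary batch axis, which is why the code has no explicit batch loop.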


Describe the bug

After `x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)`, the batch and temporal dimensions are merged into one axis, so I don't see how this forward function can handle batch inference. Could you please help me?

Reproduces the problem - code sample

No response

Reproduces the problem - command or script

No response

Reproduces the problem - error message

No response

Additional information

No response