The forward function of UniFormerV2 is:
```python
import torch

def forward(self, x: torch.Tensor) -> torch.Tensor:
    print(x.shape)  # debug: input clip, [N, 3, T, H, W]
    x = self.conv1(x)  # shape = [N, width, T, grid, grid]
    N, C, T, H, W = x.shape
    # Fold batch and time together: every frame becomes its own token sequence.
    x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
    # Prepend one class token to each of the N * T sequences.
    x = torch.cat([
        self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
    ],
                  dim=1)  # shape = [N * T, grid ** 2 + 1, width]
    x = x + self.positional_embedding.to(x.dtype)
    x = self.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND
    out = self.transformer(x)
    return out
```
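To make the shape flow concrete, the sketch below follows the tensor shape through each line of forward() for a given batch size, without needing torch. It assumes conv1 is a non-overlapping patch embedding (a 3D conv whose stride equals its kernel of (t_kernel, patch, patch)); the function name trace_forward_shapes and the defaults patch=16, t_kernel=1 and embed_dim=768 are hypothetical values chosen for illustration, not read from the mmaction2 config.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def trace_forward_shapes(batch: int, frames: int, height: int, width_px: int,
                         patch: int = 16, t_kernel: int = 1,
                         embed_dim: int = 768) -> Dict[str, Shape]:
    """Follow the tensor shape through each step of forward() (hypothetical helper).

    Assumes conv1 is a non-overlapping patch embedding with
    kernel = stride = (t_kernel, patch, patch); that is an assumption.
    """
    # Input video clip: [N, 3, T, H, W]
    shapes: Dict[str, Shape] = {"input": (batch, 3, frames, height, width_px)}
    # conv1 output: [N, width, T', grid_h, grid_w]
    t = frames // t_kernel
    gh, gw = height // patch, width_px // patch
    shapes["conv1"] = (batch, embed_dim, t, gh, gw)
    # permute(0, 2, 3, 4, 1) -> [N, T', gh, gw, width], then reshape(N * T', gh * gw, width)
    shapes["tokens"] = (batch * t, gh * gw, embed_dim)
    # torch.cat prepends one class token to every sequence
    shapes["with_cls"] = (batch * t, gh * gw + 1, embed_dim)
    # permute(1, 0, 2): NLD -> LND, so the batch-time axis moves to dim 1
    shapes["transformer_in"] = (gh * gw + 1, batch * t, embed_dim)
    return shapes


if __name__ == "__main__":
    for name, shape in trace_forward_shapes(4, 8, 224, 224).items():
        print(f"{name:>15}: {shape}")
```

With 4 clips of 8 frames at 224x224, the transformer receives a tensor of shape (197, 32, 768): the batch size never disappears, it is simply multiplied into the second dimension together with the frame count.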
I don't understand how this function handles batch inference. Could you please help me?
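Batch handling hinges on the reshape(N * T, H * W, C) line: batch and time are folded into one dimension, so N clips of T frames become N * T token sequences, and a later layer that needs per-clip structure can undo the fold as long as it knows T. The pure-Python sketch below mimics that fold and its inverse, with lists standing in for tensors; fold_batch_time and unfold_batch_time are hypothetical names, not mmaction2 functions. That the UniFormerV2 transformer recovers N by dividing the second dimension by its configured frame count is an assumption to verify against the source.

```python
from typing import List, TypeVar

Item = TypeVar("Item")


def fold_batch_time(clips: List[List[Item]]) -> List[Item]:
    """Mimic reshape(N * T, ...): flatten [N][T] frames into one list of length N * T.

    Clip n, frame t lands at index n * T + t, the row-major (logical) order
    that reshape uses on the permuted [N, T, ...] tensor.
    """
    return [frame for clip in clips for frame in clip]


def unfold_batch_time(frames: List[Item], num_frames: int) -> List[List[Item]]:
    """Invert the fold: recover [N][T] given the known frame count T (assumed known)."""
    if num_frames <= 0 or len(frames) % num_frames != 0:
        raise ValueError("length must be a positive multiple of num_frames")
    n = len(frames) // num_frames
    return [frames[i * num_frames:(i + 1) * num_frames] for i in range(n)]


if __name__ == "__main__":
    clips = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]  # N = 2 clips, T = 3 frames
    flat = fold_batch_time(clips)
    print(flat)                        # six per-frame sequences, clip by clip
    print(unfold_batch_time(flat, 3))  # back to two clips of three frames
```

Because the fold only reorders data, running with N = 1 or N = 32 changes the size of that combined dimension but not the per-frame computation; the batch is recovered exactly when T is known.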
Branch
main branch (1.x version, such as v1.0.0, or dev-1.x branch)
Prerequisite
Environment
Reproduces the problem - code sample
No response
Reproduces the problem - command or script
No response
Reproduces the problem - error message
No response
Additional information
No response