Dear author, I have some questions about the following code:
```python
#### inter Attention ######
x = x.unfold(dimension=1, size=self.patch_size, step=self.stride)  # [b x patch_num x nvar x dim x patch_len]
x = x.permute(0, 2, 1, 3, 4)  # [b x nvar x patch_num x dim x patch_len]
b, nvar, patch_num, dim, patch_len = x.shape
x = torch.reshape(x, (
    x.shape[0] * x.shape[1], x.shape[2], x.shape[3] * x.shape[-1]))  # [b*nvar, patch_num, dim*patch_len]
```
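To make the shape comments concrete, the snippet can be traced with dummy tensors. This is a minimal sketch that assumes the input to this block has layout `[b, seq_len, nvar, dim]`; the sizes and the `patch_size`/`stride` values below are hypothetical stand-ins, not the repository's actual settings.

```python
import torch

# Hypothetical sizes for illustration only (not taken from the original code).
b, seq_len, nvar, dim = 2, 16, 3, 8
patch_size, stride = 4, 4  # stand-ins for self.patch_size and self.stride

# Assumed input layout: [b, seq_len, nvar, dim]
x = torch.randn(b, seq_len, nvar, dim)

# unfold slides a window along dim 1 (time); the window contents become a new last dim
x = x.unfold(dimension=1, size=patch_size, step=stride)
patch_num = (seq_len - patch_size) // stride + 1
assert x.shape == (b, patch_num, nvar, dim, patch_size)  # [b, patch_num, nvar, dim, patch_len]

# Move the variable axis next to the batch axis
x = x.permute(0, 2, 1, 3, 4)
assert x.shape == (b, nvar, patch_num, dim, patch_size)  # [b, nvar, patch_num, dim, patch_len]

# Merge b and nvar, and flatten each patch (torch.reshape copies if the tensor is non-contiguous)
merged = torch.reshape(x, (b * nvar, patch_num, dim * patch_size))
print(merged.shape)  # torch.Size([6, 4, 32])

# With row-major reshape, row i * nvar + j holds the patches of sample i, variable j
i, j = 1, 2
assert torch.equal(merged[i * nvar + j], x[i, j].reshape(patch_num, dim * patch_size))
```

Because the reshape is row-major, merging the first two axes is lossless, and `merged.reshape(b, nvar, patch_num, dim * patch_size)` recovers the per-sample, per-variable layout.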
Could you please explain what `nvar` means, and why the first two dimensions can be merged into `b x nvar`? Does `b x nvar` act as a new batch dimension?