wxie9 / CARD


Forward pass doesn't work due to einsum error #3

Open eiriksteen opened 4 days ago

eiriksteen commented 4 days ago

I am trying to reproduce the results from your paper, but the forward pass fails. Specifically, the last line of the snippet below raises an error that is hard to debug for someone not familiar with the code:

    def forward(self, src, *args, **kwargs):

        B, nvars, H, C, = src.shape

        qkv = self.qkv(src).reshape(B, nvars, H, 3, self.n_heads,
                                    C // self.n_heads).permute(3, 0, 1, 4, 2, 5)

        q, k, v = qkv[0], qkv[1], qkv[2]

        if not self.over_hidden:

            attn_score_along_token = torch.einsum(
                'bnhed,bnhfd->bnhef', self.ema(q), self.ema(k)) / self.head_dim ** -0.5

            attn_along_token = self.attn_dropout(
                F.softmax(attn_score_along_token, dim=-1))

            output_along_token = torch.einsum(
                'bnhef,bnhfd->bnhed', attn_along_token, v)

        else:

            v_dp, k_dp = self.dynamic_projection(
                v, self.dp_v), self.dynamic_projection(k, self.dp_k)
            attn_score_along_token = torch.einsum(
                'bnhed,bnhfd->bnhef', self.ema(q), self.ema(k_dp)) / self.head_dim ** -0.5

This is the error:

  File "/Users/eiriksteen/Personal/work/relu/code/vqformer/src/vqformer/models/baselines/card.py", line 131, in forward
    output_1 = a_1(inputs.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/eiriksteen/miniconda3/envs/ofa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/eiriksteen/miniconda3/envs/ofa/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/eiriksteen/Personal/work/relu/code/vqformer/src/vqformer/models/baselines/card.py", line 251, in forward
    'bnhed,bnhfd->bnhef', self.ema(q), self.ema(k_dp)) / self.head_dim ** -0.5
                          ^^^^^^^^^^^
  File "/Users/eiriksteen/Personal/work/relu/code/vqformer/src/vqformer/models/baselines/card.py", line 205, in ema
    return torch.einsum('bnhad,ga ->bnhgd', src, self.ema_matrix[:src.shape[-2], :src.shape[-2]])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/eiriksteen/miniconda3/envs/ofa/lib/python3.12/site-packages/torch/functional.py", line 386, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: einsum(): subscript a has size 8 for operand 1 which does not broadcast with previously seen size 21

Any idea what the problem is?
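One thing the traceback suggests: the slice `self.ema_matrix[:src.shape[-2], :src.shape[-2]]` can only shrink the matrix, never grow it. A minimal sketch of that behaviour, assuming the stored matrix is 8x8 and the sliced axis of `src` has size 21 (both numbers taken from the error message). NumPy is used here as a stand-in for torch, since both clamp out-of-range slices the same way; `ema_matrix` and the leading dimensions of `src` are hypothetical placeholders:

```python
import numpy as np

n_tokens = 21  # size of axis -2 of src, per the error message
ema_size = 8   # size of the stored matrix, per the error message

# Hypothetical stand-in for self.ema_matrix; the real values do not matter here.
ema_matrix = np.eye(ema_size)

# Slicing past the end is silently clamped, so the result is still 8x8.
sliced = ema_matrix[:n_tokens, :n_tokens]
print(sliced.shape)

# Placeholder tensor with the same rank as src; only axis -2 matters.
src = np.zeros((2, 3, 4, n_tokens, 5))

try:
    np.einsum('bnhad,ga->bnhgd', src, sliced)
    failed = False
except ValueError as exc:
    # Subscript 'a' is 21 in src but 8 in the sliced matrix.
    failed = True
    print(exc)
```

If this reading is right, the matrix would need to be at least as large as axis -2 of the tensor passed to `ema` for the slice to produce a matching shape.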

wxie9 commented 3 days ago

The error seems to come from the self.ema() function: the dimensions of q/k and self.ema_matrix don't match. Could you share the exact shapes of q/k and self.ema_matrix in your experiment?

eiriksteen commented 10 hours ago

Thank you for the quick response! These are the shapes:

q: torch.Size([32, 7, 8, 21, 32]), ema: torch.Size([8, 8])
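With these shapes, the conflict in the `ema` einsum can be checked by hand or with a small helper. The sketch below is not part of CARD; `einsum_conflicts` is a hypothetical pure-Python function that maps each subscript of the equation to the sizes it is given and reports the ones that disagree:

```python
def einsum_conflicts(equation, *shapes):
    """Return {subscript: sizes} for subscripts bound to incompatible sizes."""
    # Keep only the input side of the equation and ignore whitespace.
    inputs = equation.replace(' ', '').split('->')[0].split(',')
    if len(inputs) != len(shapes):
        raise ValueError('number of operands does not match the equation')

    seen = {}
    for subscripts, shape in zip(inputs, shapes):
        if len(subscripts) != len(shape):
            raise ValueError(f'{subscripts!r} does not match rank of {shape}')
        for letter, size in zip(subscripts, shape):
            seen.setdefault(letter, set()).add(size)

    # Size 1 broadcasts, so only two or more distinct non-1 sizes conflict.
    return {k: sorted(v) for k, v in seen.items() if len(v - {1}) > 1}


q_shape = (32, 7, 8, 21, 32)  # shape of q reported above
ema_shape = (8, 8)            # shape of the ema matrix reported above

conflicts = einsum_conflicts('bnhad,ga ->bnhgd', q_shape, ema_shape)
print(conflicts)
```

The only conflicting subscript is `a`: it is axis -2 of `q` (size 21) but the second axis of the 8x8 matrix. Note that 8 is also the size of the `h` axis of `q`, which may or may not be a coincidence.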