rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

`rf.RelPosCausalSelfAttention` fails with `single_step_dim` #1585

Open LucaG1 opened 1 month ago

LucaG1 commented 1 month ago

Hi, I'm having a problem with rf.RelPosCausalSelfAttention when using it in a transformer decoder. It fails in the function _rel_pos_enc_shift because it tries to remove single_step_dim from a tensor that does not have this dim: https://github.com/rwth-i6/returnn/blob/23d666ccf3ac9e748fce4e0d27afe353133eca48/returnn/frontend/attention.py#L412

https://github.com/rwth-i6/returnn/blob/23d666ccf3ac9e748fce4e0d27afe353133eca48/returnn/frontend/attention.py#L533

The input matrix_bd looks like this: Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}

The error I get looks like this:

        line: matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
        locals:
          matrix_bd = <local> Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}
          _rel_pos_enc_shift = <global> <function _rel_pos_enc_shift at 0x7f78c8937ac0>
          axis = <local> Dim{'single-step'!}
          pos_emb_spatial_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
          hist_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
      File "returnn/returnn/frontend/attention.py", line 412, in _rel_pos_enc_shift
        line: batch_dims = x.remaining_dims((axis, pos_emb_spatial_dim))
        locals:
          batch_dims = <not found>
          x = <local> Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}
          x.remaining_dims = <local> <bound method _TensorMixin.remaining_dims of Tensor{'dot', ['initial-beam'(1),B?,'num_heads'(8),'self_att_expand_dim_init+1'(1)]}>
          axis = <local> Dim{'single-step'!}
          pos_emb_spatial_dim = <local> Dim{'self_att_expand_dim_init+1'(1)}
      File "returnn/returnn/tensor/_tensor_extra.py", line 1849, in _TensorMixin.remaining_dims
        line: batch_dims.remove(remove_)
        locals:
          batch_dims = <local> [Dim{'initial-beam'(1)}, Dim{B}, Dim{'num_heads'(8)}, Dim{'self_att_expand_dim_init+1'(1)}]
          batch_dims.remove = <local> <built-in method remove of list object at 0x7f7811ab6900>
          remove_ = <local> Dim{'single-step'!}
    ValueError: list.remove(x): x not in list
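The failure mechanism can be reproduced in isolation. Below is a simplified, hypothetical model of remaining_dims (dims are plain strings, not RETURNN's Dim objects), assuming it copies the dim list and calls list.remove for each dim to drop, as the traceback suggests. Since single_step_dim is only a marker and is never part of the tensor's dims, the very first remove call fails.

```python
from typing import List, Sequence


def remaining_dims(dims: Sequence[str], remove: Sequence[str]) -> List[str]:
    """Hypothetical, simplified model of Tensor.remaining_dims (not RETURNN code):
    copy the dim list and remove each requested dim in turn."""
    batch_dims = list(dims)
    for remove_ in remove:
        # list.remove raises ValueError if remove_ is not in the list
        batch_dims.remove(remove_)
    return batch_dims


# Dims of matrix_bd from the traceback, modeled as plain strings.
matrix_bd_dims = ["initial-beam", "B", "num_heads", "self_att_expand_dim_init+1"]

try:
    # axis is the single-step marker, which matrix_bd does not contain
    remaining_dims(matrix_bd_dims, ["single-step", "self_att_expand_dim_init+1"])
except ValueError as exc:
    print("ValueError:", exc)

# Removing only dims that actually exist works as expected.
print(remaining_dims(matrix_bd_dims, ["self_att_expand_dim_init+1"]))
```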

I don't have an easy setup for you to reproduce this yet. However, I think it should be easy to reproduce by using rf.RelPosCausalSelfAttention with single_step_dim.

I also need to look deeper into the functionality behind this in order to understand what the correct behaviour would be.

If I have any new information on this, I will post it here.

albertz commented 1 month ago

I think it's just not implemented yet.

albertz commented 1 month ago

What's the state here? @LucaG1, do you have a fix for this? I thought you were already using this?

LucaG1 commented 1 month ago

Right, sorry, I forgot to post it here. This fix currently works for me, but I am still not sure whether it is the correct way to do this:

        if axis == single_step_dim:
            matrix_bd = rf.expand_dim(matrix_bd, axis)

        matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)

        if axis == single_step_dim:
            matrix_bd = rf.squeeze(matrix_bd, axis)

So it just adds single_step_dim before the call to _rel_pos_enc_shift, removes it again afterwards, and hopes that the function does the right thing in that case as well.

albertz commented 1 month ago

relative_positional_encoding needs a proper query_offset in case of single step, or not?
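Why a query offset matters can be seen from the relative distances alone. A minimal NumPy sketch, assuming relative positions are defined as key position minus query position; rel_distances is a hypothetical helper, not a RETURNN function. In a single step, the only query frame is the newest frame of the history, so unless it is given its absolute position as an offset, it is treated as position 0.

```python
import numpy as np


def rel_distances(query_len: int, key_len: int, query_offset: int = 0) -> np.ndarray:
    """Hypothetical helper: key position minus query position for each pair.

    query_offset is the absolute position of the first query frame."""
    query_pos = np.arange(query_len)[:, None] + query_offset
    key_pos = np.arange(key_len)[None, :]
    return key_pos - query_pos


T = 4
full = rel_distances(T, T)  # whole sequence at once
# Step t = T - 1: one query frame attending to a history of T frames.
step_with_offset = rel_distances(1, T, query_offset=T - 1)
step_without_offset = rel_distances(1, T)

print(full[T - 1])             # [-3 -2 -1  0]
print(step_with_offset[0])     # [-3 -2 -1  0]
print(step_without_offset[0])  # [0 1 2 3]
```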

albertz commented 1 month ago

Also, rf.expand_dim(matrix_bd, single_step_dim) does not make sense. I wonder whether that even works? It should throw an exception: single_step_dim is not allowed to be part of the shape of an actual tensor.

LucaG1 commented 1 month ago

I checked, and I think it does not need any query offset. In my case, _rel_pos_enc_shift no longer changes matrix_bd. I guess the correct thing to do would then be:

        if axis != single_step_dim:
            matrix_bd = _rel_pos_enc_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)

If you want, I can push this fix.
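The observation that the shift no longer changes matrix_bd is what one would expect from the pad-and-reshape "relative shift" trick known from Transformer-XL. The NumPy sketch below assumes _rel_pos_enc_shift follows that trick (not verified against the RETURNN source); rel_shift is a hypothetical stand-alone version. With a single query row the reshapes leave the data untouched, so skipping the call in the single-step case looks harmless.

```python
import numpy as np


def rel_shift(x: np.ndarray) -> np.ndarray:
    """Transformer-XL style relative shift (generic sketch, not RETURNN's code).

    x has shape [q, p]: row i holds scores of query i for p relative positions.
    Returns an array of the same shape with out[i, j] == x[i, j + q - 1 - i]
    wherever that index is valid.
    """
    q, p = x.shape
    padded = np.pad(x, ((0, 0), (1, 0)))  # prepend a zero column: [q, p + 1]
    padded = padded.reshape(p + 1, q)      # reinterpret the memory layout
    return padded[1:].reshape(q, p)        # drop the first q entries


T = 3
x = np.arange(T * (2 * T - 1)).reshape(T, 2 * T - 1)  # toy scores, p = 2T - 1
shifted = rel_shift(x)
# Column c of x stands for relative distance c - (T - 1); after the shift,
# column j of row i holds distance j - i, i.e. the usual query/key layout.
expected = np.array([[x[i, j + T - 1 - i] for j in range(T)] for i in range(T)])
print(np.array_equal(shifted[:, :T], expected))  # True

# With a single query row (the single-step case) the shift is the identity.
row = np.array([[5, 6, 7]])
print(np.array_equal(rel_shift(row), row))  # True
```

The identity holds for any number of relative positions, so it says nothing about whether the positional encoding itself was computed for the right query position; that is the query_offset question raised above.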

albertz commented 1 month ago

> I checked and I think it does not need any query offset.

Why? That sounds incorrect. Surely a (rel or abs) positional encoding must somehow depend on the position?

albertz commented 1 month ago

It would be good to also have a test case that operates on the whole sequence at once, then operates step by step, and checks that both give exactly the same output.
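A minimal NumPy sketch of such a consistency test, using a toy causal attention with a scalar relative-position bias instead of the actual rf.RelPosCausalSelfAttention (causal_rel_attention, run_step_by_step and rel_bias are hypothetical). The step-by-step run only matches the full-sequence result when each step passes its own position as the query offset.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 5, 4
queries = rng.standard_normal((T, D))
keys = rng.standard_normal((T, D))
values = rng.standard_normal((T, D))
# Hypothetical learned bias, one scalar per relative distance 0..T-1.
rel_bias = rng.standard_normal(T)


def causal_rel_attention(q, k, v, query_offset=0):
    """Toy causal self-attention with a relative-position bias.

    q: [n_q, D], queries at absolute positions query_offset .. query_offset + n_q - 1
    k, v: [n_k, D], keys/values at absolute positions 0 .. n_k - 1
    """
    query_pos = np.arange(q.shape[0])[:, None] + query_offset
    key_pos = np.arange(k.shape[0])[None, :]
    dist = query_pos - key_pos  # >= 0 for past and current keys
    scores = q @ k.T + rel_bias[np.clip(dist, 0, None)]
    scores = np.where(dist >= 0, scores, -np.inf)  # causal mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def run_step_by_step(use_offset: bool) -> np.ndarray:
    """Feed one query frame per step, attending to the history so far."""
    outputs = []
    for t in range(T):
        offset = t if use_offset else 0
        outputs.append(
            causal_rel_attention(
                queries[t : t + 1], keys[: t + 1], values[: t + 1], query_offset=offset
            )
        )
    return np.concatenate(outputs)


full = causal_rel_attention(queries, keys, values)
print(np.allclose(full, run_step_by_step(use_offset=True)))   # True
print(np.allclose(full, run_step_by_step(use_offset=False)))  # False
```

In this toy setup the correct offset at step t equals the history length minus one.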

LucaG1 commented 1 month ago

> Why? That sounds incorrect. Surely a (rel or abs) positional encoding must somehow depend on the position?

Ah, sorry, my bad, I was thinking of something else. It seems to me that query_offset is computed automatically for the single_step_dim case here: https://github.com/rwth-i6/returnn/blob/61ad52a72916d5834a211ea11a8536388a0d7864/returnn/frontend/attention.py#L762