idiap / fast-transformers

PyTorch library for fast transformer implementations

forward() got multiple values for argument 'state' #39

Closed · hadaev8 closed this 4 years ago

hadaev8 commented 4 years ago

Sorry to bother you, I can't tell whether it's my mistake or a bug in the library. I'm doing sampling like this:

with torch.no_grad():
    trg_tensor = torch.LongTensor([p2idx['SOS'], ]).unsqueeze(0).to(device)
    state = None
    out_token = trg_tensor
    for i in range(max_len):
        # decoder_mask = TriangularCausalMask(trg_tensor.size(1), device=device)
        # decoder_len_mask = LengthMask(trg_tensor.new_full((trg_tensor.shape[0],), trg_tensor.shape[1], dtype=torch.int64))

        # Embed the current token and add the positional encoding for step i.
        output = model.pos_decoder(model.decoder(out_token), i)
        # The recurrent decoder returns (output, state); project only the output.
        output, state = model.transformer_decoder_rnn(output.squeeze(1), memory, memory_length_mask=encoder_len_mask, state=state)
        output = model.fc_out(output)
        # Greedy decoding: take the most likely token at the last position.
        out_token = output.argmax(-1)[:, -1].unsqueeze(0)
        trg_tensor = torch.cat([trg_tensor, out_token], dim=-1)
        if out_token == p2idx['EOS']:
            break

Full code and traceback: https://colab.research.google.com/drive/1mYTh4MO_Tg6LBrhhVQUd81R92UNE56F7?usp=sharing
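
(Editor's note: the key contract in the loop above is that the recurrent modules in this library return a pair per step, the output for the current token and an updated state that must be fed back in on the next step. A minimal sketch of that contract with a hypothetical stand-in module, not the library's actual classes:)

    import torch

    class StepDecoder(torch.nn.Module):
        """Hypothetical stand-in for a per-step (recurrent) decoder."""
        def forward(self, x, memory, state=None):
            # Cache everything seen so far, as a recurrent layer would.
            state = [] if state is None else state
            state.append(x)
            # Toy computation standing in for attention over the memory.
            return x + memory.mean(dim=1), state

    decoder = StepDecoder()
    memory = torch.randn(1, 7, 16)   # pretend encoder output
    x = torch.randn(1, 16)           # embedding of the first token
    state = None
    for _ in range(3):
        x, state = decoder(x, memory, state=state)  # unpack, then project x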

hadaev8 commented 4 years ago

Aha, found it: RecurrentAttentionLayer has no memory_lengths argument
https://github.com/idiap/fast-transformers/blob/ea9cb3b1751f0b2f1e661b087f3dd3ec8a413ab0/fast_transformers/recurrent/attention/self_attention/attention_layer.py#L54
but RecurrentTransformerDecoderLayer tries to pass one:
https://github.com/idiap/fast-transformers/blob/8acb570071926605da1d7f22d4c1239be2d80b55/fast_transformers/recurrent/transformers.py#L215
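
(Editor's note: the traceback makes sense in that light. When a caller passes the length mask positionally to an attention layer whose signature has no slot for it, the extra positional argument lands in the state parameter, and the explicit state= keyword then collides with it. A minimal reproduction in plain Python, with simplified hypothetical signatures rather than the library's real ones:)

    # Simplified stand-ins for the two forward() signatures.
    def self_attention_forward(query, key, value, state=None):
        return query, state

    def cross_attention_forward(query, keys, values, key_lengths=None, state=None):
        return query, state

    q = k = v = lengths = object()

    # Fine: the cross-attention signature has a slot for the lengths.
    cross_attention_forward(q, k, v, lengths, state=None)

    # Breaks: `lengths` fills the `state` slot positionally, and the
    # state= keyword then hits the same parameter.
    try:
        self_attention_forward(q, k, v, lengths, state=None)
    except TypeError as e:
        print(e)  # ... got multiple values for argument 'state'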

hadaev8 commented 4 years ago

Never mind, I didn't realize there is also a RecurrentCrossAttentionLayer. Kind of complicated, honestly.