syncdoth / RetNet

Huggingface-compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf), including parallel, recurrent, and chunkwise forward.
MIT License

passing attention_mask doesn't work for recurrent #15

Closed infiniteperplexity closed 1 year ago

infiniteperplexity commented 1 year ago

For example:

prompts = ["My dog is cute.", "My cat is very cute.", "Both my cat and dog are very cute."] tokenizer.pad_token_id = tokenizer.eos_token_id inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True) outputs = model(**inputs, forward_impl='recurrent')

--> 114 current_kv = decay * past_key_value + retention_mask * (k.transpose(-1, -2) @ v)

RuntimeError: The size of tensor a (27) must match the size of tensor b (3) at non-singleton dimension 0

It looks like the retention mask is being improperly reshaped so that all dimensions except the batch dimension are 1. However, it's not totally clear to me how the retention mask is supposed to work in a RetNet, so I can't say for sure what's going wrong.
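To make the shape problem concrete, here is a minimal sketch of that recurrent state update (tensor names and dimensions are illustrative assumptions, not the repo's actual internals): a per-example mask needs shape [bs, 1, 1, 1] to broadcast over the state, whereas a mask reshaped to the padded sequence length (27) cannot broadcast against a batch of 3.

import torch

bs, num_heads, qk_dim, v_dim = 3, 4, 8, 8

past_key_value = torch.zeros(bs, num_heads, qk_dim, v_dim)  # recurrent state S_{n-1}
k = torch.randn(bs, num_heads, 1, qk_dim)                   # current step's key
v = torch.randn(bs, num_heads, 1, v_dim)                    # current step's value
decay = torch.rand(1, num_heads, 1, 1)                      # per-head decay gamma

# A per-example mask that broadcasts over the state has shape [bs, 1, 1, 1].
retention_mask = torch.ones(bs, 1, 1, 1)
current_kv = decay * past_key_value + retention_mask * (k.transpose(-1, -2) @ v)  # OK

# A mask reshaped to the padded sequence length (27) instead of the batch size (3)
# fails to broadcast at dimension 0, which is exactly the reported RuntimeError:
bad_mask = torch.ones(27, 1, 1, 1)
# decay * past_key_value + bad_mask * (k.transpose(-1, -2) @ v)  # -> size mismatch, 27 vs 3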

syncdoth commented 1 year ago

The main idea of including attention_mask/retention_mask was to allow left padding of the sequences.

When using the recurrent forward, the assumption is that input_ids/retention_mask have shape [bs, 1], i.e. a sequence length of 1. A full example would be:


prompts = ["My dog is cute.", "My cat is very cute.", "Both my cat and dog are very cute."]
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True).to(device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

past_kv = None
logits = []
for i in range(input_ids.shape[1]):
    rnn_out = model(input_ids[:, i:i+1], 
                    attention_mask=attention_mask[:, i:i+1],  # attention_mask == retention_mask
                    forward_impl='recurrent',
                    past_key_values=past_kv,
                    use_cache=True,
                    sequence_offset=i)
    logits.append(rnn_out.logits)
    past_kv = rnn_out.past_key_values

I will make this fact more evident in the code.
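As a usage note beyond the original reply: the per-step logits can be concatenated back into a full-sequence tensor and, assuming the parallel implementation is selected the same way (e.g. forward_impl='parallel', inferred from the repo description rather than confirmed here), compared against the parallel forward as a sanity check.

import torch

# Each rnn_out.logits is assumed to have shape [bs, 1, vocab_size].
full_logits = torch.cat(logits, dim=1)  # [bs, seq_len, vocab_size]

# Optional sanity check against the parallel forward (forward_impl value assumed).
# Logits at padded positions may differ and can be ignored via attention_mask.
with torch.no_grad():
    par_out = model(input_ids, attention_mask=attention_mask, forward_impl='parallel')
print(torch.allclose(full_logits, par_out.logits, atol=1e-4))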

infiniteperplexity commented 1 year ago

Thanks, makes sense now.