Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
I don't know if I should input attention_mask in the SFT process #40
Like in your example code:

```python
input_ids = torch.LongTensor([[1, 2, 3, 4, 1, 2, 5, 5],
                              [5, 5, 1, 2, 3, 4, 1, 2]]).to(device)
retention_mask = torch.LongTensor([[1, 1, 1, 1, 1, 1, 0, 0],
                                   [0, 0, 1, 1, 1, 1, 1, 1]]).to(device)

parallel_outputs = model(input_ids, retention_mask=retention_mask,
                         forward_impl='parallel', use_cache=True)
```
If I want to do SFT training, should I pass in retention_mask, or do I only need input_ids and labels?
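For reference, this is roughly what I imagine the SFT forward pass would look like. This is only a sketch based on my assumptions: I'm assuming the model accepts a `labels` argument and returns a loss like other Hugging Face causal-LM models, and that `-100` is the ignore index for padded positions, neither of which I've confirmed in this repo.

```python
import torch

# Same toy batch as above; retention_mask marks real tokens (1) vs padding (0),
# analogous to attention_mask.
input_ids = torch.LongTensor([[1, 2, 3, 4, 1, 2, 5, 5],
                              [5, 5, 1, 2, 3, 4, 1, 2]]).to(device)
retention_mask = torch.LongTensor([[1, 1, 1, 1, 1, 1, 0, 0],
                                   [0, 0, 1, 1, 1, 1, 1, 1]]).to(device)

# Assumption: mask out padded positions in the loss with -100 (HF convention).
labels = input_ids.masked_fill(retention_mask == 0, -100)

# Assumption: the model computes the loss internally when `labels` is given.
outputs = model(input_ids,
                retention_mask=retention_mask,
                labels=labels,
                forward_impl='parallel')
loss = outputs.loss
loss.backward()
```

Is this the right way to do it, or is retention_mask unnecessary here?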