syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License
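For reference, a minimal sketch of how the three forward implementations might be invoked. The 'parallel' call and its arguments come from the snippet in this issue; the 'recurrent' and 'chunkwise' string values, the model-loading step, and the past_key_values handling are assumptions extrapolated from the repo description and the Huggingface conventions it claims to follow.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_ids = torch.LongTensor([[1, 2, 3, 4, 1, 2]]).to(device)

model = ...  # load a RetNet model from this repo (exact class name omitted; see the README)

# Parallel: attend over the whole sequence at once (the training-style
# forward, as used in the snippet below).
parallel_out = model(input_ids, forward_impl='parallel')

# Recurrent: one token at a time, carrying the recurrent state forward
# (assumed to reuse the HF-style past_key_values/use_cache mechanism).
past_kv = None
for t in range(input_ids.shape[1]):
    step_out = model(input_ids[:, t:t + 1], forward_impl='recurrent',
                     past_key_values=past_kv, use_cache=True)
    past_kv = step_out.past_key_values

# Chunkwise: parallel within fixed-size chunks, recurrent across chunks;
# a middle ground between the two modes above.
chunk_out = model(input_ids, forward_impl='chunkwise')
```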

I don't know if I should input attention_mask in the SFT process #40

Open wac81 opened 7 months ago

wac81 commented 7 months ago

Like in your code:

input_ids = torch.LongTensor([[1,2,3,4,1,2,5,5], [5,5,1,2,3,4,1,2]]).to(device)
retention_mask = torch.LongTensor([[1,1,1,1,1,1,0,0], [0,0,1,1,1,1,1,1]]).to(device)

parallel_outputs = model(input_ids, retention_mask=retention_mask, forward_impl='parallel', use_cache=True)

If I want to do SFT training, should I pass in retention_mask, or do I just need input_ids and labels?
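For context, here is a minimal sketch of what the two options in the question look like, reusing the tensors from the snippet above. The labels handling (clone input_ids and set padded positions to -100) follows the standard Huggingface convention; whether this repo's model accepts a labels argument, and whether retention_mask is required for correct padding behavior, are exactly the assumptions being asked about here.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ...  # same RetNet model as above (loading elided)

# Same padded batch as in the snippet above: token 5 acts as padding,
# at the end of the first sequence and the start of the second.
input_ids = torch.LongTensor([[1, 2, 3, 4, 1, 2, 5, 5],
                              [5, 5, 1, 2, 3, 4, 1, 2]]).to(device)
retention_mask = torch.LongTensor([[1, 1, 1, 1, 1, 1, 0, 0],
                                   [0, 0, 1, 1, 1, 1, 1, 1]]).to(device)

# Either way, padded positions should be excluded from the loss; -100 is
# the standard Huggingface ignore_index for cross-entropy.
labels = input_ids.clone()
labels[retention_mask == 0] = -100

# Option A: pass retention_mask, so padding is masked inside the retention
# computation as well as in the loss.
outputs = model(input_ids, retention_mask=retention_mask,
                labels=labels, forward_impl='parallel')

# Option B: input_ids and labels only; the loss ignores padding, but padded
# tokens could still leak into the retention states of real tokens.
outputs = model(input_ids, labels=labels, forward_impl='parallel')
```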