This allows masking variable-length sequences with the correct masking probability (e.g. p = 0.6) computed with respect to the sequence mask, rather than the padded batch width.
I had to branch the masking behavior when the mask id is provided, because the existing batch -> mask -> rearrange-back-to-batch path doesn't work for ragged sequences, since each sequence in the batch can have a different number of tokens. @lucidrains let me know if you are ok with this approach or if you have a better idea 🙏
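For reference, here's a minimal sketch of the kind of per-sequence masking this enables (the function and variable names are illustrative, not the actual code in this PR): the number of tokens to mask is derived from each sequence's unpadded length, and padding positions are excluded from selection.

```python
import torch

def mask_with_seq_mask(seq, seq_mask, mask_id, prob = 0.6):
    # random score per position; padding gets +inf so it sorts
    # last and is never selected for masking
    rand = torch.rand(seq.shape, device = seq.device)
    rand = rand.masked_fill(~seq_mask, float('inf'))

    # per-sequence count of tokens to mask, relative to the *real*
    # sequence length rather than the padded width
    num_to_mask = (seq_mask.sum(dim = -1) * prob).ceil().long()

    # rank positions by random score; the `num_to_mask` lowest-ranked
    # positions in each row are replaced with the mask id
    ranks = rand.argsort(dim = -1).argsort(dim = -1)
    to_mask = ranks < num_to_mask.unsqueeze(-1)

    return seq.masked_fill(to_mask, mask_id)

# usage: batch of 2 ragged sequences, padded to length 6
seq = torch.tensor([[5, 3, 8, 2, 0, 0],
                    [7, 1, 4, 9, 6, 2]])
seq_mask = torch.tensor([[True, True, True, True, False, False],
                         [True, True, True, True, True, True]])
print(mask_with_seq_mask(seq, seq_mask, mask_id = 100))
```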
I also added an option to return the logits along with the loss, to allow computing downstream metrics (the metrics themselves are not included in this PR). On top of that, I put in some minor fixes for the decoder depth and cleaned up some issues with the logging from the original trainer work.
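The logits option follows a pattern roughly like the sketch below; the `return_logits` kwarg name and the toy `forward` are illustrative assumptions, not the actual signature in this PR.

```python
import torch
import torch.nn.functional as F

# toy stand-in for the model's loss computation; `return_logits`
# is an assumed kwarg name for illustration only
def forward(seq, labels, vocab_size = 512, return_logits = False):
    logits = torch.randn(seq.shape[0], seq.shape[1], vocab_size)  # stand-in for decoder output
    loss = F.cross_entropy(logits.transpose(1, 2), labels)
    if return_logits:
        # caller can compute downstream metrics (accuracy, etc.)
        # from the logits without a second forward pass
        return loss, logits
    return loss

seq = torch.randint(0, 512, (2, 6))
labels = torch.randint(0, 512, (2, 6))
loss, logits = forward(seq, labels, return_logits = True)
```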