When I needed the three original loss scores, I returned and printed them, and found that the values were all very large, which is not normal. So I checked how the loss is calculated and found what appears to be an error in the weight calculation.
In toolformer_pytorch.py line 427, you replace elements of weights with 0 where weights == pad_id (-1). However, the elements of t (obtained from get_arange_start_at_token_id()) that equal -1 all become 1.2 after passing through weighting_fn(), so no element of weights equals -1 any more and the mask never matches.
So when we calculate the weighted loss, the positions we don't want to include (e.g., before ) get a weight of 1.2 instead of 0, which makes the final loss value very large.
Changing the condition at toolformer_pytorch.py line 427 to (t == pad_id) or (weights == 1.2) fixes this.
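To make the problem concrete, here is a minimal plain-Python sketch of the masking logic, using lists in place of tensors. It is not the library's code: the function names `arange_start_at_token`, `weights_buggy`, and `weights_fixed` are hypothetical, and the weighting function (1 - 0.2 * t, clamped at 0) is an assumption chosen because it maps -1 to 1.2 as described above.

```python
PAD_ID = -1  # padding value, as described in this issue


def weighting_fn(t):
    # Assumed weighting: 1 - 0.2 * t, clamped at 0. It maps t = -1 to 1.2.
    return max(1.0 - 0.2 * t, 0.0)


def arange_start_at_token(token_ids, token_id, pad_id=PAD_ID):
    # Hypothetical stand-in for get_arange_start_at_token_id():
    # positions before the first `token_id` get pad_id,
    # positions from it onwards count 0, 1, 2, ...
    out, count = [], None
    for tok in token_ids:
        if count is None and tok == token_id:
            count = 0
        elif count is not None:
            count += 1
        out.append(pad_id if count is None else count)
    return out


def weights_buggy(token_ids, token_id, pad_id=PAD_ID):
    t = arange_start_at_token(token_ids, token_id, pad_id)
    weights = [weighting_fn(x) for x in t]
    # The mask is tested on the weights, which never equal pad_id,
    # so the positions that should be ignored keep the weight 1.2.
    return [0.0 if w == pad_id else w for w in weights]


def weights_fixed(token_ids, token_id, pad_id=PAD_ID):
    t = arange_start_at_token(token_ids, token_id, pad_id)
    weights = [weighting_fn(x) for x in t]
    # The mask is tested on t, where the pad positions are still marked.
    return [0.0 if x == pad_id else w for x, w in zip(t, weights)]


if __name__ == "__main__":
    tokens = [5, 7, 9, 3, 4]  # made-up ids; 9 plays the role of the start token
    print("t:     ", arange_start_at_token(tokens, 9))
    print("buggy: ", weights_buggy(tokens, 9))
    print("fixed: ", weights_fixed(tokens, 9))
```

Of the two conditions suggested above, masking on t == pad_id seems the more robust one, since weights == 1.2 relies on an exact floating-point comparison and only holds for this particular weighting function.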