mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0
9.58k stars, 846 forks

Fix sliding window mask size in one_file_ref.py #137

Open MrYxJ opened 6 months ago

MrYxJ commented 6 months ago

Fixes a minor off-by-one error in the sliding window mask in one_file_ref.py: the original mask kept one more column per row than the expected window size.

MrYxJ commented 6 months ago
import torch

tensor = torch.ones((5, 5))
sliding_window = 3

# causal mask: each position attends only to itself and earlier positions
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# mask = torch.triu(mask, diagonal=-sliding_window)  # original: keeps sliding_window + 1 entries per row
mask = torch.triu(mask, diagonal=-sliding_window + 1)  # fixed: keeps exactly sliding_window entries per row
mask = torch.log(mask)  # 1 -> 0., 0 -> -inf
mask
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [-inf, 0., 0., 0., -inf],
        [-inf, -inf, 0., 0., 0.]])
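
To see where the extra column comes from, here is a minimal plain-Python sketch of the index rule behind the two calls. It assumes that `torch.tril(x, diagonal=0)` keeps entries with `j - i <= 0` and that `torch.triu(x, diagonal=d)` keeps entries with `j - i >= d`. The helper names `band_mask` and `max_row_width` are hypothetical and are not part of one_file_ref.py.

```python
def band_mask(n, lower_diagonal):
    """Return an n x n 0/1 mask keeping entries with lower_diagonal <= j - i <= 0.

    This mirrors tril(diagonal=0) followed by triu(diagonal=lower_diagonal).
    """
    return [
        [1 if lower_diagonal <= j - i <= 0 else 0 for j in range(n)]
        for i in range(n)
    ]


def max_row_width(mask):
    """Largest number of positions any single row is allowed to attend to."""
    return max(sum(row) for row in mask)


n, sliding_window = 5, 3

# Original call: diagonal=-sliding_window
original = band_mask(n, -sliding_window)
# Proposed fix: diagonal=-sliding_window + 1
fixed = band_mask(n, -sliding_window + 1)

print(max_row_width(original))  # 4, one more than sliding_window
print(max_row_width(fixed))     # 3, equal to sliding_window
```

The band `d <= j - i <= 0` contains `-d + 1` diagonals, so `d = -sliding_window` yields `sliding_window + 1` visible positions per row, while `d = -sliding_window + 1` yields exactly `sliding_window`.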