microsoft / unilm

Large-scale Self-supervised Pre-training Across Tasks, Languages, and Modalities
https://aka.ms/GeneralAI

Possible bugs in YOCO #1612

Closed: x54-729 closed this issue 1 week ago

x54-729 commented 2 weeks ago
  1. In SlidingWindowAttention, self.window_size is set to sliding_size - 1 at init: https://github.com/microsoft/unilm/blob/378d4280ebf68a2e10d74c9e8081823934b65249/YOCO/yoco/models/decoder/sliding_window_attention.py#L21 But in forward it is reduced by 1 again when calling flash_attn_func: https://github.com/microsoft/unilm/blob/378d4280ebf68a2e10d74c9e8081823934b65249/YOCO/yoco/models/decoder/sliding_window_attention.py#L64
  2. Also in SlidingWindowAttention, the handling of the key/value cache does not look right: https://github.com/microsoft/unilm/blob/378d4280ebf68a2e10d74c9e8081823934b65249/YOCO/yoco/models/decoder/sliding_window_attention.py#L50-L64 When calling flash_attn_func, the cached k/v are not concatenated with the new ones (a rough sketch of both points follows below).
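
For reference, a rough sketch of both points, assuming the public flash-attn >= 2.3 interface (flash_attn_func with a window_size=(left, right) argument). This is not the YOCO code; the function and variable names are illustrative only:

```python
# A rough sketch (NOT the actual YOCO code) of a sliding-window attention
# forward that uses a KV cache with flash-attn >= 2.3. All names here
# (cache_k, cache_v, sliding_window) are illustrative.
import torch
from flash_attn import flash_attn_func

def sliding_window_forward(q, k, v, cache_k, cache_v, sliding_window):
    # q, k, v:          (batch, new_len, n_heads, head_dim), fp16/bf16 on CUDA
    # cache_k, cache_v: (batch, cached_len, n_heads, head_dim)
    # Concatenate the cache with the new keys/values before attention,
    # otherwise the new queries cannot attend to past tokens (point 2 above).
    k = torch.cat([cache_k, k], dim=1)
    v = torch.cat([cache_v, v], dim=1)

    # With causal=True and window_size=(W, 0), each query attends to the W keys
    # before it plus itself, i.e. W + 1 positions in total. So attending to
    # exactly `sliding_window` tokens needs a left window of sliding_window - 1;
    # subtracting 1 at init and again in forward (point 1 above) would shrink
    # the window by one extra token.
    out = flash_attn_func(
        q, k, v,
        causal=True,
        window_size=(sliding_window - 1, 0),
    )
    return out, k, v  # the updated cache is the concatenated k/v
```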
donglixp commented 2 weeks ago

https://github.com/microsoft/unilm/commit/53ed1159c596f33af5b228f6041f6d7ffee963c0

donglixp commented 2 weeks ago

fixed

x54-729 commented 2 weeks ago

https://github.com/microsoft/unilm/blob/53ed1159c596f33af5b228f6041f6d7ffee963c0/YOCO/yoco/models/decoder/sliding_window_attention.py#L52-L53 https://github.com/microsoft/unilm/blob/53ed1159c596f33af5b228f6041f6d7ffee963c0/YOCO/yoco/models/decoder/sliding_window_attention.py#L58-L59

After the fix, which one is the correct argument here, self.window_size or self.window_size - 1?

Edit: the same question applies to this if statement: https://github.com/microsoft/unilm/blob/53ed1159c596f33af5b228f6041f6d7ffee963c0/YOCO/yoco/models/decoder/sliding_window_attention.py#L57

sunyt32 commented 1 week ago

The exact window size follows the FlashAttention interface. The result is always correct as long as the provided KV cache is no smaller than the real window size; the code is written this way just for convenience.
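
To illustrate that point, a minimal check (not from the YOCO repo; it assumes flash-attn >= 2.3 and a CUDA device): with a left window of `window - 1`, a new query sees at most the last `window` keys, so any extra KV cache beyond that is masked out and the output is unchanged.

```python
# Illustrative check: passing a KV cache longer than the window does not
# change the result, because the sliding-window mask already restricts each
# query to the last `window` keys. Assumes flash-attn >= 2.3 on CUDA.
import torch
from flash_attn import flash_attn_func

torch.manual_seed(0)
batch, heads, dim, window = 1, 4, 64, 8
q = torch.randn(batch, 1, heads, dim, device="cuda", dtype=torch.float16)
kv = torch.randn(batch, 32, heads, dim, device="cuda", dtype=torch.float16)

# Full cache (32 tokens) vs. a cache trimmed to exactly `window` tokens.
out_full = flash_attn_func(q, kv, kv, causal=True, window_size=(window - 1, 0))
out_trim = flash_attn_func(q, kv[:, -window:], kv[:, -window:],
                           causal=True, window_size=(window - 1, 0))
print(torch.allclose(out_full, out_trim, atol=1e-3))  # True
```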

x54-729 commented 1 week ago

Thanks!