jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0
529 stars, 33 forks

Does the input sharding match exact optimization of the long sequence? #3

Closed: guanzhchen closed this issue 3 months ago

guanzhchen commented 3 months ago

Thanks for your exciting work!

I found that the extract_local function seems to split an input sequence of length L into chunks of length L/world_size, one per rank. Are parameters optimized (backward) for each chunk rather than the whole long sequence? So have you checked whether this introduces any approximation errors, or whether the optimization is length-agnostic?
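For readers unfamiliar with the sharding step, here is a minimal pure-Python sketch of splitting a length-L sequence so that each rank holds L/world_size tokens. It assumes the zigzag layout used by ring-flash-attention's zigzag variant: the sequence is cut into 2 * world_size chunks and rank r keeps chunks r and 2 * world_size - 1 - r, which balances causal-attention work across ranks. The function name `extract_local_sketch` is hypothetical, and this is an illustration, not EasyContext's actual implementation.

```python
def extract_local_sketch(tokens, rank, world_size):
    """Return the tokens kept by `rank` under a zigzag split (hypothetical helper)."""
    n_chunks = 2 * world_size
    if len(tokens) % n_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * world_size")
    size = len(tokens) // n_chunks
    chunks = [tokens[i * size:(i + 1) * size] for i in range(n_chunks)]
    # Pair an early chunk with a late one so every rank does similar causal work.
    return chunks[rank] + chunks[n_chunks - 1 - rank]


if __name__ == "__main__":
    seq = list(range(8))  # L = 8 token positions
    for r in range(2):    # world_size = 2
        print(r, extract_local_sketch(seq, r, 2))
    # prints:
    # 0 [0, 1, 6, 7]
    # 1 [2, 3, 4, 5]
```

Each rank stores activations for only L/world_size tokens; the split by itself says nothing about which keys each rank's queries attend to.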

jzhang38 commented 3 months ago

> Are parameters optimized (backward) for each chunk rather than the whole long sequence?

The whole sequence.
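Why sharding does not reduce this to per-chunk optimization: in ring attention, each rank keeps its own query chunk but receives every other rank's key/value chunk in turn, merging the partial results with an online (log-sum-exp) softmax. The forward pass therefore computes the same attention as on the full sequence, and the backward pass differentiates that same function. The pure-Python sketch below illustrates only the merging step on toy data, for a single query; it omits causal masking, multiple heads, and inter-GPU communication, and the helper names `full_attention` and `blockwise_attention` are hypothetical.

```python
import math


def full_attention(q, keys, values):
    """Reference softmax attention for one query over all keys (no masking)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]


def blockwise_attention(q, key_blocks, value_blocks):
    """Same attention, computed one key/value block at a time, as each ring step would."""
    m = -math.inf  # running max of all scores seen so far
    z = 0.0        # running softmax denominator, scaled by exp(-m)
    acc = None     # running weighted sum of values, scaled by exp(-m)
    for keys, values in zip(key_blocks, value_blocks):
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        new_m = max(m, max(scores))
        scale = math.exp(m - new_m)  # rescale earlier partial sums to the new max
        weights = [math.exp(s - new_m) for s in scores]
        dim = len(values[0])
        block_acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
        acc = block_acc if acc is None else [a * scale + b for a, b in zip(acc, block_acc)]
        z = z * scale + sum(weights)
        m = new_m
    return [a / z for a in acc]


if __name__ == "__main__":
    keys = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.5, 0.5]]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 3.0]]
    q = [0.7, -0.3]
    ref = full_attention(q, keys, values)
    # Two "ranks", each holding half of the keys/values.
    out = blockwise_attention(q, [keys[:2], keys[2:]], [values[:2], values[2:]])
    assert all(abs(a - b) < 1e-9 for a, b in zip(ref, out))
    print(ref, out)
```

Because the blockwise result matches the full-sequence result up to floating-point rounding, gradients taken through it also account for every token, not just the local chunk.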

> have you tried if there are any approximation errors or the optimization is length-agnostic?

https://github.com/zhuzilin/ring-flash-attention/blob/55ff66fd35f329dfcc24ce7a448bfdd532865966/test/test_zigzag_ring_flash_attn_func.py#L121

guanzhchen commented 3 months ago

That makes sense! Thank you!