AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Add dropping strategy #833

Closed RissyRan closed 1 month ago

RissyRan commented 1 month ago

Description

Enable token dropping for matmul implementation
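
For readers unfamiliar with token dropping, here is a minimal sketch of the idea in plain Python. It is not the code from this PR. The helper names (`expert_capacity`, `dispatch_mask`, `count_dropped`), the capacity formula, and the first-come-first-served priority are all assumptions chosen for illustration. Each expert gets a fixed number of slots, and any token assignment that arrives after the slots are full is dropped. In a matmul-style implementation, a mask of this shape would typically be contracted with the token activations to build fixed-size per-expert inputs.

```python
import math


def expert_capacity(num_tokens, num_experts, top_k, capacity_factor):
    # Assumed convention: slots per expert scale with the average load.
    return max(1, math.ceil(capacity_factor * num_tokens * top_k / num_experts))


def dispatch_mask(expert_choices, num_experts, capacity):
    """Return mask[t][e][c] == 1 when token t occupies slot c of expert e."""
    fill = [0] * num_experts  # slots already used per expert
    mask = [[[0] * capacity for _ in range(num_experts)] for _ in expert_choices]
    for t, choices in enumerate(expert_choices):
        for e in choices:
            if fill[e] < capacity:
                mask[t][e][fill[e]] = 1
                fill[e] += 1
            # Otherwise the expert is full and this assignment is dropped.
    return mask


def count_dropped(expert_choices, mask):
    requested = sum(len(choices) for choices in expert_choices)
    kept = sum(v for token in mask for expert in token for v in expert)
    return requested - kept


# Hypothetical routing: three tokens pick expert 0, one picks expert 1 (top_k = 1).
choices = [[0], [0], [0], [1]]
tight = expert_capacity(4, 2, 1, 1.0)  # 2 slots per expert
loose = expert_capacity(4, 2, 1, 2.0)  # 4 slots per expert
dropped_tight = count_dropped(choices, dispatch_mask(choices, 2, tight))
dropped_loose = count_dropped(choices, dispatch_mask(choices, 2, loose))
```

With the tight capacity, the third token routed to expert 0 finds no free slot and is dropped. With the larger capacity factor, every assignment fits, which is the "dropless" behavior referred to in the Test section.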

Next steps:

Test

Added a unit test that checks a single layer's output (the layer becomes dropless if the capacity factor is large enough):

ZhiyuLi-goog commented 1 month ago

Thank you @RissyRan for adding dropping strategy!! I just added some nit.

RissyRan commented 1 month ago

Thanks Zhiyu! Have you published comments?

ZhiyuLi-goog commented 1 month ago

Just published.