AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Add dispatch and combine masks for dropping #841

Closed: RissyRan closed this pull request 3 weeks ago.
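For context on what the title refers to, here is a minimal JAX sketch of GShard-style dispatch and combine masks with capacity-based token dropping. It is not the code from this PR. `make_dispatch_combine_masks` and `moe_forward` are hypothetical names, slots are assumed to be assigned in token order, and the "experts" are toy per-expert scales standing in for FFNs.

```python
import jax
import jax.numpy as jnp


def make_dispatch_combine_masks(top_k_indices, top_k_weights, num_experts, expert_capacity):
  """Hypothetical helper (not MaxText's generate_masks).

  top_k_indices: int [tokens, k], distinct expert ids chosen by each token.
  top_k_weights: float [tokens, k], router weight for each choice.
  Returns dispatch (bool) and combine (float), both [tokens, num_experts, expert_capacity].
  """
  tokens, k = top_k_indices.shape
  # [tokens * k, experts]: one row per (token, choice), in token order.
  routed = jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.int32).reshape(tokens * k, num_experts)
  # 0-based slot of each routed (token, choice) in its expert's buffer; -1 where not routed.
  position = (jnp.cumsum(routed, axis=0) * routed - 1).reshape(tokens, k, num_experts)
  # one_hot yields an all-zero row for -1 and for any slot >= expert_capacity,
  # so assignments that overflow an expert's capacity are dropped here.
  slots = jax.nn.one_hot(position, expert_capacity, dtype=jnp.float32)  # [tokens, k, experts, capacity]
  dispatch = slots.sum(axis=1) > 0
  combine = (slots * top_k_weights[:, :, None, None]).sum(axis=1)
  return dispatch, combine


def moe_forward(x, top_k_indices, top_k_weights, expert_scales, expert_capacity):
  num_experts = expert_scales.shape[0]
  dispatch, combine = make_dispatch_combine_masks(
      top_k_indices, top_k_weights, num_experts, expert_capacity)
  # Gather tokens into per-expert buffers: [experts, capacity, d_model].
  expert_in = jnp.einsum("tec,td->ecd", dispatch.astype(x.dtype), x)
  # Toy experts: each one just scales its inputs.
  expert_out = expert_in * expert_scales[:, None, None]
  # Scatter back, weighted by router probabilities; dropped tokens receive zeros.
  return jnp.einsum("tec,ecd->td", combine, expert_out)


x = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
top_k_indices = jnp.array([[0], [0], [0], [1]])  # three tokens want expert 0
out = moe_forward(x, top_k_indices, jnp.ones((4, 1)), jnp.array([1.0, 2.0]), expert_capacity=2)
# out == [[0, 1], [2, 3], [0, 0], [12, 14]]; token 2 overflowed expert 0 and was dropped.
```

Routing everything through two einsums keeps the shapes static, which is why dropping is expressed as zeroed mask entries rather than by removing tokens.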

RissyRan commented 1 month ago

Description

Test

A single MoE layer unit test comparing results before and after the reshape across different capacity factors (one example).
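The capacity factor controls how many assignments each expert can hold before dropping starts, which is what a sweep over capacity factors exercises. The small sketch below assumes the common convention capacity = ceil(tokens * k * capacity_factor / num_experts); MaxText's exact formula may differ. `expert_capacity` and `kept_fraction` are hypothetical helpers.

```python
import math

import jax.numpy as jnp


def expert_capacity(num_tokens, top_k, num_experts, capacity_factor):
  # Assumed convention: an equal share of the routed assignments per expert,
  # scaled by the capacity factor and rounded up.
  return math.ceil(num_tokens * top_k * capacity_factor / num_experts)


def kept_fraction(top_k_indices, num_experts, capacity_factor):
  """Fraction of (token, choice) assignments that fit within expert capacity."""
  tokens, k = top_k_indices.shape
  cap = expert_capacity(tokens, k, num_experts, capacity_factor)
  # Number of assignments routed to each expert.
  load = jnp.bincount(top_k_indices.reshape(-1), length=num_experts)
  # Each expert keeps at most `cap` assignments; the rest are dropped.
  return float(jnp.minimum(load, cap).sum()) / (tokens * k)


# Skewed top-1 routing: 5 of 8 tokens pick expert 0.
indices = jnp.array([[0], [0], [0], [0], [0], [1], [2], [3]])
for cf in (0.5, 1.0, 2.0, 2.5):
  print(cf, kept_fraction(indices, num_experts=4, capacity_factor=cf))
# Prints kept fractions 0.5, 0.625, 0.875 and 1.0 respectively.
```

Because the kept count depends only on per-expert load and capacity, not on slot ordering, it is a convenient quantity to compare between two implementations. The exact masks can still differ if a reshape changes the order in which slots are assigned.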

RissyRan commented 4 weeks ago

Looks great to me, well done! There is a lot of logic here that can easily be messed up. Have you been able to test this in some way? Ideally we would set up a unit test, but I don't see a candidate API (stable, exposed) to test except the MoE block with a capacity_factor. Perhaps generate_masks, but that may be more of an internal implementation API that is subject to change.

Yes, I tested it before and after the change in one layer; however, that test will no longer be valid once the change is merged. Let me add a unit test for the generate_masks function on a small example.

Update: added a test, and it passes.
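One way to write such a test, sketched here as an assumption rather than the test added in this PR, is to check the vectorized masks against a naive loop-based reference on a hand-worked example. `reference_dispatch_mask` is a hypothetical helper that fills each expert's slots in token order and drops assignments once an expert is full.

```python
import numpy as np


def reference_dispatch_mask(top_k_indices, num_experts, expert_capacity):
  """Hypothetical loop-based reference for a [tokens, experts, capacity] dispatch mask."""
  tokens, k = top_k_indices.shape
  dispatch = np.zeros((tokens, num_experts, expert_capacity), dtype=bool)
  filled = np.zeros(num_experts, dtype=int)  # slots used so far per expert
  for t in range(tokens):
    for j in range(k):
      e = int(top_k_indices[t, j])
      if filled[e] < expert_capacity:
        dispatch[t, e, filled[e]] = True
        filled[e] += 1
      # Otherwise expert e is full and this (token, choice) is dropped.
  return dispatch


# Hand-worked example: 3 tokens, top-2 routing, 3 experts, capacity 1.
indices = np.array([[0, 1], [1, 2], [0, 2]])
mask = reference_dispatch_mask(indices, num_experts=3, expert_capacity=1)
# Token 0 gets both choices; token 1 loses expert 1 (full) but keeps expert 2;
# token 2 is dropped entirely because experts 0 and 2 are both full.
assert mask.sum(axis=(1, 2)).tolist() == [2, 1, 0]
assert (mask.sum(axis=0) <= 1).all()  # no slot is used twice
```

A vectorized mask implementation could then be compared against this reference with `np.array_equal` on several small inputs, which keeps the test independent of the internal shapes the vectorized code uses.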