whittle-org / whittle

Python library to compress LitGPT models for resource-efficient inference.
https://whittle-org.github.io/whittle/latest/
Apache License 2.0

Support extract_sub_network for other MLP classes #139

Closed: gabikadlecova closed this issue 3 weeks ago

gabikadlecova commented 1 month ago

Is your feature request related to a problem? Please describe.
Currently, extract_sub_network works for GptNeoxMLP, but not for LLaMAMLP or GemmaMLP.

https://github.com/whittle-org/whittle/blob/3b18ba58a60ed0266438b67b0cfc272a291c7cb9/whittle/models/gpt/extract.py

Describe the solution you'd like
Fix how weights are extracted. E.g. this line should live in a function specific to GptNeoxMLP: https://github.com/whittle-org/whittle/blob/3b18ba58a60ed0266438b67b0cfc272a291c7cb9/whittle/models/gpt/extract.py#L29
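For illustration, a rough sketch of per-MLP extraction (a sketch only: extract_linear is a hypothetical helper, and the attribute names follow litgpt's MLP classes, which whittle's blocks build on; GptNeoxMLP has fc/proj while LLaMAMLP and GemmaMLP have fc_1/fc_2/proj):

```python
# Sketch: dispatch weight extraction per MLP class instead of assuming
# the GptNeoxMLP layout. extract_linear is a hypothetical helper that
# slices a supernet linear layer down to the sub-network's shape.
import torch.nn as nn
from litgpt.model import GptNeoxMLP, LLaMAMLP


def extract_linear(super_linear: nn.Linear, sub_linear: nn.Linear) -> None:
    out_features, in_features = sub_linear.weight.shape
    sub_linear.weight.data = super_linear.weight.data[:out_features, :in_features].clone()
    if sub_linear.bias is not None:
        sub_linear.bias.data = super_linear.bias.data[:out_features].clone()


def extract_mlp(super_mlp: nn.Module, sub_mlp: nn.Module) -> None:
    if isinstance(super_mlp, GptNeoxMLP):
        extract_linear(super_mlp.fc, sub_mlp.fc)
        extract_linear(super_mlp.proj, sub_mlp.proj)
    elif isinstance(super_mlp, LLaMAMLP):  # GemmaMLP subclasses LLaMAMLP
        extract_linear(super_mlp.fc_1, sub_mlp.fc_1)
        extract_linear(super_mlp.fc_2, sub_mlp.fc_2)
        extract_linear(super_mlp.proj, sub_mlp.proj)
    else:
        raise NotImplementedError(f"Unsupported MLP class: {type(super_mlp)}")
```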

If we plan to introduce other types of models/MLPs, maybe a block class should provide a dict of submodules that need calls to load_state_dict.
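And a minimal sketch of the dict-of-submodules idea (the method name extraction_submodules and the stand-in layers are hypothetical, not whittle's actual API):

```python
import torch.nn as nn


class Block(nn.Module):
    """Sketch of a GPT block that exposes its extractable submodules."""

    def __init__(self, n_embd: int, intermediate_size: int) -> None:
        super().__init__()
        self.norm_1 = nn.LayerNorm(n_embd)
        self.attn = nn.Linear(n_embd, n_embd)  # stand-in for attention
        self.norm_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, n_embd),
        )

    def extraction_submodules(self) -> dict[str, nn.Module]:
        # extract_sub_network could iterate over this dict and copy
        # each submodule's (sliced) state dict via load_state_dict.
        return {
            "norm_1": self.norm_1,
            "attn": self.attn,
            "norm_2": self.norm_2,
            "mlp": self.mlp,
        }
```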

gabikadlecova commented 3 weeks ago

I implemented the fix, but don't have a test yet. Are there any small models with LLaMAMLP that we could use, similar to the two Pythia models here? https://github.com/whittle-org/whittle/blob/main/test/test_extract.py

Or I could take MicroLlama, extract some subnet, and test extract_sub_network with that smaller model.
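Another checkpoint-free option might be a tiny random-weight model built from a litgpt Config; a hedged sketch (the whittle import paths and the extract_sub_network signature below are assumptions, not verified against the repo):

```python
import torch
from litgpt import Config

from whittle.models.gpt import GPT  # assumed import path
from whittle.models.gpt.extract import extract_sub_network


def test_extract_sub_network_llama_mlp():
    # Tiny random-weight supernet with a LLaMA-style MLP; no download needed.
    config = Config(
        block_size=8,
        n_layer=2,
        n_head=4,
        n_embd=32,
        vocab_size=64,
        mlp_class_name="LLaMAMLP",
        intermediate_size=64,
    )
    supernet = GPT(config)
    subnet = extract_sub_network(supernet, config)  # assumed signature

    # Extracting the full network should reproduce the supernet's outputs.
    x = torch.randint(0, config.vocab_size, (1, config.block_size))
    torch.testing.assert_close(subnet(x), supernet(x))
```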

rheasukthanker commented 3 weeks ago

@aaronkl one thing I noticed is that we are not extracting the weights of the norm layers in extract_sub_network. Should it be that way, or am I missing something?

gabikadlecova commented 3 weeks ago

@rheasukthanker Yes, IMO we should extract those too.

Since the test uses an uninitialized model, it passed anyway: the default init for norm layers is all-ones weights and zero biases, so the missing copy is invisible. How should I change the test? One quick (if ugly) idea is to initialize the supernet's norm weights to random values.
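A small sketch of that idea (the helper name is hypothetical; matching modules by name avoids depending on litgpt's concrete norm classes):

```python
import torch.nn as nn


def randomize_norm_layers(model: nn.Module) -> None:
    # Perturb norm parameters away from the ones/zeros default init so
    # that a missing norm copy in extract_sub_network actually fails
    # the test instead of passing by accident.
    for name, module in model.named_modules():
        if "norm" in name.lower() and hasattr(module, "weight"):
            nn.init.normal_(module.weight)
            if getattr(module, "bias", None) is not None:
                nn.init.normal_(module.bias)
```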

rheasukthanker commented 3 weeks ago

I agree, we should make the tests more robust and randomize the init for subnet extraction, along with the fix for the norms. But let's do that in a separate PR (from the current one); I'll create a new issue.