gabikadlecova closed this 3 weeks ago
I implemented the fix, but don't have a test yet. Are there some small models with LLamaMLP that we could use? Something like the two Pythia models in https://github.com/whittle-org/whittle/blob/main/test/test_extract.py
Or, I could take MicroLLama, extract some subnet, and test extract_sub_network with this smaller model.
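(If no published checkpoint fits, a tiny hand-rolled config might also be enough for a test. A rough sketch, assuming whittle's GPT consumes a litgpt-style Config and that mlp_class_name selects the MLP implementation; the sizes below are arbitrary placeholders:)

```python
# Hypothetical tiny LLaMA-style config for a fast test; all sizes are
# placeholders, and this assumes whittle's GPT accepts a litgpt Config.
from litgpt import Config

tiny_llama_config = Config(
    block_size=128,
    vocab_size=512,
    n_layer=2,
    n_head=4,
    n_embd=64,
    intermediate_size=128,
    mlp_class_name="LLaMAMLP",   # exercise the LLamaMLP extraction path
    norm_class_name="RMSNorm",
)
```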
@aaronkl one thing I noticed is that we are not extracting the weights of the norm layers in extract_sub_network. Is that intentional, or am I missing something?
@rheasukthanker Yes, IMO we should extract the norm weights as well.
Since the test uses an uninitialized model, it passed even without copying the norm weights: the default init for norm layers is all ones for the weight and zeros for the bias, so the supernet and subnet norms matched anyway. How should I change the test? One quick (if ugly) idea is to initialize the supernet's norm weights to random values.
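Roughly what I have in mind, as a sketch only (whittle's models may use their own norm classes, so the isinstance check below would need to match those rather than plain nn.LayerNorm):

```python
# Sketch: randomize the norm parameters of the supernet so the test can no
# longer pass by accident when extract_sub_network skips the norm weights.
import torch
import torch.nn as nn


def randomize_norm_layers(model: nn.Module) -> None:
    """Overwrite the default all-ones weight / zero bias init with random values."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            with torch.no_grad():
                if module.weight is not None:
                    module.weight.copy_(torch.randn_like(module.weight))
                if module.bias is not None:
                    module.bias.copy_(torch.randn_like(module.bias))
```

The test would call this on the supernet before extraction and then compare the extracted subnet's norm parameters against the supernet's.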
I agree, we should make the tests more robust and randomize the init for subnet extraction, along with the fix for the norms. But let's do that in a separate PR (separate from the current one). I'll create a new issue.
Is your feature request related to a problem? Please describe.
Currently, extract_sub_network works for GPTNeoxMLP, but not for LLamaMLP or GemmaMLP: https://github.com/whittle-org/whittle/blob/3b18ba58a60ed0266438b67b0cfc272a291c7cb9/whittle/models/gpt/extract.py
Describe the solution you'd like
Fix how the weights are extracted; e.g. this line should be in a function specific to GPTNeoxMLP: https://github.com/whittle-org/whittle/blob/3b18ba58a60ed0266438b67b0cfc272a291c7cb9/whittle/models/gpt/extract.py#L29
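For illustration only, something like the following (the class-name strings follow the spellings used in this issue and would need to match the actual MLP classes; slicing the super-network weights down to the sub-network sizes is glossed over here):

```python
# Hypothetical dispatch so the GPTNeoxMLP-specific line no longer lives in the
# generic extraction code; the submodule names assume fc/proj for GPTNeoxMLP
# and fc_1/fc_2/proj for LLamaMLP/GemmaMLP.
import torch.nn as nn


def mlp_submodule_names(mlp: nn.Module) -> list[str]:
    cls = type(mlp).__name__
    if cls == "GPTNeoxMLP":
        return ["fc", "proj"]
    if cls in ("LLamaMLP", "GemmaMLP"):
        return ["fc_1", "fc_2", "proj"]
    raise NotImplementedError(f"extract_sub_network does not support {cls}")


def copy_mlp_weights(sub_mlp: nn.Module, super_mlp: nn.Module) -> None:
    # In the real code the super-network tensors would first be sliced to the
    # sub-network dimensions before load_state_dict is called.
    for name in mlp_submodule_names(super_mlp):
        getattr(sub_mlp, name).load_state_dict(getattr(super_mlp, name).state_dict())
```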
If we plan to introduce other types of models/MLPs, maybe a block class should provide a dict of submodules that need calls to load_state_dict.
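As a rough illustration of that (the method name extractable_submodules is invented for this sketch, not existing whittle API, and the layer names mirror a LLaMA-style MLP):

```python
# Hypothetical interface: each MLP/block class declares which submodules
# extract_sub_network should transfer via load_state_dict.
import torch.nn as nn


class LLamaStyleMLP(nn.Module):
    def __init__(self, n_embd: int, intermediate_size: int) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(n_embd, intermediate_size)
        self.fc_2 = nn.Linear(n_embd, intermediate_size)
        self.proj = nn.Linear(intermediate_size, n_embd)

    def extractable_submodules(self) -> dict:
        return {"fc_1": self.fc_1, "fc_2": self.fc_2, "proj": self.proj}


def copy_block_weights(sub_mlp: nn.Module, super_mlp: nn.Module) -> None:
    # The generic extraction loop then needs no per-class branches; slicing the
    # super-network weights to the sub-network sizes is again omitted here.
    for name, super_module in super_mlp.extractable_submodules().items():
        getattr(sub_mlp, name).load_state_dict(super_module.state_dict())
```

Compared to dispatching on the class name, this keeps the knowledge of which weights matter inside each MLP class.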