huggingface / nn_pruning

Prune a model while finetuning or training.
Apache License 2.0

No weights removed during fine-pruning? #5

Closed lewtun closed 3 years ago

lewtun commented 3 years ago

Hello François,

I've put together a simple text classification example (link) using the SparseTrainer and it seems that no weights are being removed during fine-pruning.

From the nn_pruning docs my understanding is that I need to take the following steps:

Create a mixin with SparseTrainer and Trainer

Since I'm not doing anything fancy like question-answering, I created the following class:

class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        We override the default loss in SparseTrainer because it throws an 
        error when run without distillation
        """
        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss

where I override the default compute_loss function because it throws a TypeError: iteration over a 0-d tensor when a teacher model is not provided (I want to try without distillation first). I think the error comes from distil_loss_combine, which returns a single value, whereas compute_loss expects two.

Instantiate trainer with sparse training arguments

With the above mixin, my idea was to use the default values in SparseTrainingArguments along with the usual things we need in a HF Trainer:

sparse_args = SparseTrainingArguments()

trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc['train'],
    eval_dataset=boolq_enc['validation'],
    tokenizer=bert_tokenizer,
    compute_metrics=compute_metrics
)

By default SparseTrainingArguments has initial_threshold=1 and final_threshold=0.5 so my understanding is that by the end of fine-pruning we expect around 50% of the encoder weights to be removed.
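For reference, a threshold that moves from initial_threshold to final_threshold is usually annealed over the course of training. The sketch below assumes a cubic schedule of the kind described for movement pruning; the function name, the warmup_steps parameter, and the exact formula are illustrative assumptions and are not taken from nn_pruning's code.

```python
def scheduled_threshold(step, total_steps, warmup_steps,
                        initial_threshold=1.0, final_threshold=0.5,
                        initial_warmup=1, final_warmup=3):
    """Hypothetical cubic schedule from initial_threshold to final_threshold.

    The threshold stays at initial_threshold for the first
    initial_warmup * warmup_steps steps, and at final_threshold for the
    last final_warmup * warmup_steps steps. In between it decays cubically.
    """
    start = initial_warmup * warmup_steps
    end = total_steps - final_warmup * warmup_steps
    if step <= start:
        return initial_threshold
    if step >= end:
        return final_threshold
    # Fraction of the decay phase still remaining, in (0, 1).
    remaining = 1.0 - (step - start) / (end - start)
    return final_threshold + (initial_threshold - final_threshold) * remaining ** 3


# Sample the schedule at a few steps of a 1000-step run.
thresholds = [scheduled_threshold(s, total_steps=1000, warmup_steps=100)
              for s in (0, 100, 400, 700, 1000)]
```

Under these assumptions, a run whose total number of steps is small compared with the warm-up phases reaches the final threshold early in training.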

Set the trainer's patch coordinator

Here I took a guess based on your instructions for fine-pruning without a trainer and set the patch coordinator as follows:

mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args, 
    device=device, 
    cache_dir="checkpoints", 
    logit_names="logits", 
    teacher_constructor=AutoModelForSequenceClassification)

trainer.set_patch_coordinator(mpc)

Fine-tune

Running

trainer.train()

seems to show the model is learning, although curiously the mask threshold in the logs is already 0.5 after the first epoch.

Optimize the model for inference

Following your example I ran

prunebert_model = optimize_model(trainer.model.to("cpu"), "dense")

but find that no parameters are removed 😢. So it seems that although the model is learning, I have missed something to enable pruning during fine-tuning.

Any ideas on what step(s) I'm missing?

P.S. What is the meaning of "XP" in classes like SparseXP?

lewtun commented 3 years ago

Update: I realised that I also need to patch the model, so I took the following approach:

bert_model = ...
mpc = ModelPatchingCoordinator(...)
mpc.patch_model(bert_model)  # returns LAYER NORM PATCH {'patched': 72}
trainer = PruningTrainer(...)
trainer.set_patch_coordinator(mpc)

Now here's the weird thing: if I compile and optimize the model before fine-pruning, then I see 50% of the weights are removed (the default in SparseTrainingArguments):

mpc.compile_model(trainer.model)
# prints removed heads 72, total_heads=144, percentage removed=0.5 ...
prunebert_model = optimize_model(trainer.model, "dense") 

But if I fine-prune the model first, I find none of the weights are removed

trainer.train()
mpc.compile_model(trainer.model)
# prints removed heads 0, total_heads=144, percentage removed=0.0 ...
prunebert_model = optimize_model(trainer.model, "dense") 

Evidently I am missing some crucial step in the configuration of the fine-pruning trainer and perhaps I should be wrapping this in a SparseXP mixin as you do in the SQuAD / GLUE examples - is this necessary?

madlag commented 3 years ago

Your code was actually 100% correct, but here is why you did not get a correct network.

Here is an updated notebook. With the default parameter values, you will not see any visible pruning from this method, because the pruning is completely unstructured (the historical behaviour of movement pruning): dense_block_rows=dense_block_cols=attention_block_rows=attention_block_cols=1

I have to keep it this way to make sure old runs are still compatible with the code, for reproducibility reasons.

That said, I will provide ways to create correct sparse_args, with the various methods that can be used.
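To illustrate why the default 1x1 block size mentioned above leaves nothing for a dense optimizer to remove, here is a minimal pure-Python sketch. It is not nn_pruning's implementation: topk_block_mask and empty_rows are hypothetical helpers, and the 64x64 matrix with random scores is made up for the demonstration. Both masks keep 50% of the weights, but only the structured one produces rows that are entirely zero and could therefore be dropped from the matrix.

```python
import random


def topk_block_mask(scores, block_rows, block_cols, density):
    """Keep the `density` fraction of blocks with the highest summed score."""
    n_rows, n_cols = len(scores), len(scores[0])
    blocks = []
    for r in range(0, n_rows, block_rows):
        for c in range(0, n_cols, block_cols):
            total = sum(scores[i][j]
                        for i in range(r, min(r + block_rows, n_rows))
                        for j in range(c, min(c + block_cols, n_cols)))
            blocks.append((total, r, c))
    blocks.sort(reverse=True)
    keep = blocks[:round(density * len(blocks))]
    mask = [[0] * n_cols for _ in range(n_rows)]
    for _, r, c in keep:
        for i in range(r, min(r + block_rows, n_rows)):
            for j in range(c, min(c + block_cols, n_cols)):
                mask[i][j] = 1
    return mask


def empty_rows(mask):
    """Number of rows that are entirely masked out."""
    return sum(1 for row in mask if not any(row))


random.seed(0)
scores = [[random.random() for _ in range(64)] for _ in range(64)]

unstructured = topk_block_mask(scores, 1, 1, 0.5)   # 1x1 blocks: single weights
row_blocks = topk_block_mask(scores, 1, 64, 0.5)    # 1x64 blocks: whole rows

print("empty rows, 1x1 blocks :", empty_rows(unstructured))
print("empty rows, 1x64 blocks:", empty_rows(row_blocks))
```

With single-weight blocks the zeros are scattered, so the matrix keeps its full shape even at 50% sparsity; with row-sized blocks half of the rows vanish.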

So for the moment, here is the code snippet I used in the updated notebook to create correct parameters. I added some comments to explain their roles.

sparse_args = SparseTrainingArguments()

d = {
  "initial_warmup": 1,
  "final_warmup": 3,
  "initial_threshold": 1.0, # When using topK set to 1 (initial density). With sigmoied_threshold, use 0.0 (cutoff)
  "final_threshold": 0.5, # When using topK, this is the final density. With sigmoied_threshold, use 0.1 (final cutoff, which is a bit arbitrary of course, set regularization_final_lambda to adjust final sparsity)
  "dense_pruning_method": "topK:1d_alt", #"sigmoied_threshold:1d_alt",
  "dense_block_rows":1,
  "dense_block_cols":1,
  "dense_lambda":0.25,
  "attention_pruning_method": "topK", #"sigmoied_threshold",
  "attention_block_rows":32,
  "attention_block_cols":32,
  "attention_lambda":1.0,
  "ampere_pruning_method": "disabled",
  "mask_init": "constant",
  "mask_scale": 0.0,
  "regularization": None, # "l1" when pruning_methods are sigmoied_threshold
  "regularization_final_lambda": 20, # To be tweaked to adjust sparsity : the higher, the more sparse. Try different values by multiplying by 2x several times
  "distil_teacher_name_or_path":None,
  "distil_alpha_ce": 0.1,
  "distil_alpha_teacher": 0.9,
  "attention_output_with_dense": 0,
  "layer_norm_patch" : 0,
  "gelu_patch":0
}

for k,v in d.items():
  if hasattr(sparse_args, k):
    setattr(sparse_args, k, v)
  else:
    print(f"sparse_args does not have an argument {k}")

Another aspect is that head pruning is a kind of emergent property when using the "sigmoied_threshold" method, so it's not 100% guaranteed that you will get any full head pruning with it on just a test run: regularization has to do its job for some time to make heads disappear.
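As a rough illustration of what "full head pruning" means here, the sketch below counts heads whose slice of a mask is entirely zero. It assumes that a head owns a contiguous group of rows and that a head can only be removed once every weight in that group is masked; count_prunable_heads is a hypothetical helper, not a function from nn_pruning.

```python
def count_prunable_heads(mask, num_heads):
    """Count heads whose contiguous slice of rows in `mask` is all zeros."""
    rows_per_head = len(mask) // num_heads
    pruned = 0
    for h in range(num_heads):
        head_rows = mask[h * rows_per_head:(h + 1) * rows_per_head]
        if not any(any(row) for row in head_rows):
            pruned += 1
    return pruned


# 4 heads of 2 rows each: head 1 is fully masked, head 3 only partly.
mask = [
    [1, 0, 1], [0, 1, 1],   # head 0
    [0, 0, 0], [0, 0, 0],   # head 1 (removable)
    [1, 1, 1], [1, 0, 0],   # head 2
    [0, 0, 0], [0, 1, 0],   # head 3 (one surviving weight keeps it alive)
]
pruned = count_prunable_heads(mask, num_heads=4)
```

A single surviving weight is enough to keep a head, which is why heads tend to disappear only after regularization has pushed down all of their scores.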

With the topK method you are guaranteed to remove X% of each layer, but it's not the best method in terms of accuracy. That said, topK is a good starting point for testing that everything works OK.
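The difference between the two families of methods can be sketched in a few lines of plain Python. This is a simplified illustration over a flat list of scores, with hypothetical helper names; it ignores block structure and is not the library's code. topK keeps a fixed fraction of the weights whatever their scores are, while the sigmoid threshold keeps a weight only if the sigmoid of its score exceeds the cutoff, so the resulting density depends on how far regularization has pushed the scores down.

```python
import math


def topk_mask(scores, density):
    """Keep the `density` fraction of entries with the highest scores."""
    k = round(density * len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [1 if i in keep else 0 for i in range(len(scores))]


def sigmoid_threshold_mask(scores, threshold):
    """Keep entries whose sigmoid(score) is above `threshold`."""
    return [1 if 1.0 / (1.0 + math.exp(-s)) > threshold else 0 for s in scores]


scores = [2.0, 0.5, -1.0, -3.0]
# Pretend regularization has lowered every score by 2.
regularized = [s - 2.0 for s in scores]

topk_before = topk_mask(scores, density=0.5)
topk_after = topk_mask(regularized, density=0.5)
sig_before = sigmoid_threshold_mask(scores, threshold=0.1)
sig_after = sigmoid_threshold_mask(regularized, threshold=0.1)

print("topK   :", topk_before, "->", topk_after)
print("sigmoid:", sig_before, "->", sig_after)
```

In this toy case the topK mask is unchanged by the shift, whereas the sigmoid-threshold mask drops one more weight once the scores are lowered.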

I added some code to visualize two typical layers, one in attention and one in the FFN, so you can get an idea of what's happening.

I hope this helps!

xihajun commented 1 year ago

@lewtun May I ask whether you managed to find a good value for dense_block_rows for pruning?