microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

How to specify the target of distilling when there are multiple outputs? #5600

Closed hobbitlzy closed 1 year ago

hobbitlzy commented 1 year ago

Describe the issue: I am using the example for pruning BERT on MNLI. I ran into a problem with distillation after changing the output format of the encoder layers, as shown below.

layer_outputs = layer_module(hidden_states, ...) # original BERT encoder
layer_outputs, second_output = layer_module(hidden_states, ...) # modified BERT encoder

I changed the configuration of the distiller as follows:

def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                      student_trainer: Trainer):
    layer_num = len(student_model.bert.encoder.layer)
    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}'],
        'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
        'lambda': 0.9,
        'apply_method': 'mse',
        'target_names': ['_output_0']  # newly added line to specify which output is used for distillation
    } for i in range(layer_num)]
    config_list.append({
        'op_names': ['classifier'],
        'link': ['classifier'],
        'lambda': 0.9,
        'apply_method': 'kl',
    })

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)

    return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)

However, the error shown in the log message below occurs. Could you help me with this problem?

Environment:

Log message:

Traceback (most recent call last):
  File "/home/work/project/prune_bert_glue/pruning_bert_glue.py", line 178, in <module>
    finetuned_model=post_distillation(task_name, finetuned_model, teacher_model, output_dir=stage_softmax_dir)
  File "/home/work/project/prune_bert_glue/prune_modules.py", line 107, in post_distillation
    dynamic_distillation(task_name, model, copy.deepcopy(teacher_model), output_dir, None, 3)
  File "/home/work/project/prune_bert_glue/distiller.py", line 52, in dynamic_distillation
    distiller.compress(max_steps, max_epochs)
  File "/home/work/project/nni/nni/contrib/compression/base/compressor.py", line 190, in compress
    self._single_compress(max_steps, max_epochs)
  File "/home/work/project/nni/nni/contrib/compression/distillation/basic_distiller.py", line 144, in _single_compress
    self._fusion_compress(max_steps, max_epochs)
  File "/home/work/project/nni/nni/contrib/compression/base/compressor.py", line 183, in _fusion_compress
    self.evaluator.train(max_steps, max_epochs)
  File "/home/work/project/nni/nni/contrib/compression/utils/evaluator.py", line 1084, in train
    self.trainer.train()
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/work/project/nni/nni/contrib/compression/utils/evaluator.py", line 1044, in patched_compute_loss
    result = old_compute_loss(model, inputs, return_outputs)
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/transformers/trainer.py", line 2767, in compute_loss
    outputs = model(**inputs)
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/project/prune_bert_glue/models/modeling_bert.py", line 1592, in forward
    outputs = self.bert(
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/project/prune_bert_glue/models/modeling_bert.py", line 1050, in forward
    encoder_outputs = self.encoder(
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/project/prune_bert_glue/models/modeling_bert.py", line 627, in forward
    layer_outputs, sparse_mask = layer_module(
  File "/home/work/anaconda3/envs/nni/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/project/nni/nni/contrib/compression/base/wrapper.py", line 407, in forward
    outputs = self.patch_outputs(outputs)
  File "/home/work/project/nni/nni/contrib/compression/base/wrapper.py", line 369, in patch_outputs
    new_outputs.append(self.patch_helper(target_name, target))
  File "/home/work/project/nni/nni/contrib/compression/base/wrapper.py", line 331, in patch_helper
    target = self._distil_observe_helper(target, self.distillation_target_spaces[target_name])
  File "/home/work/project/nni/nni/contrib/compression/base/wrapper.py", line 303, in _distil_observe_helper
    target_space.hidden_state = target
  File "/home/work/project/nni/nni/contrib/compression/base/target_space.py", line 387, in hidden_state
    raise TypeError('Only support saving tensor as distillation hidden_state.')
TypeError: Only support saving tensor as distillation hidden_state.

J-shang commented 1 year ago

Hello @hobbitlzy, sorry, the distiller currently only supports links between layers with a single output. We will add support for multi-output layers in the future. As a workaround, you could add a torch.nn.Identity at the end of the encoder layer, just before it returns the multiple outputs:

class Encoder(...):
    def __init__(...):
        ...
        self.idt = torch.nn.Identity()

    def forward(...):
        ...
        layer_outputs = self.idt(layer_outputs)
        return layer_outputs, second_output

Then modify your config_list:

    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}.idt'],  # add idt
        'link': [f'bert.encoder.layer.{j}.idt' for j in range(i, layer_num)],  # add idt
        'lambda': 0.9,
        'apply_method': 'mse',
    } for i in range(layer_num)]
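
For illustration, here is a minimal sketch of what this comprehension expands to. It assumes a hypothetical 3-layer model (in the real script, layer_num comes from len(student_model.bert.encoder.layer)), and build_config_list is a helper name introduced only for this example, not part of NNI. It only shows which module names end up in 'op_names' and 'link'.

```python
# Hypothetical layer count, chosen only for illustration; the real value
# comes from len(student_model.bert.encoder.layer).
layer_num = 3


def build_config_list(layer_num):
    """Build the layer-wise config entries that target the '.idt' submodules.

    This helper is an assumption made for the example; it mirrors the list
    comprehension from the comment above.
    """
    return [{
        'op_names': [f'bert.encoder.layer.{i}.idt'],
        # Layer i is linked to layers i .. layer_num - 1.
        'link': [f'bert.encoder.layer.{j}.idt' for j in range(i, layer_num)],
        'lambda': 0.9,
        'apply_method': 'mse',
    } for i in range(layer_num)]


config_list = build_config_list(layer_num)
for entry in config_list:
    print(entry['op_names'][0], '->', entry['link'])
```

Every name in the list must match a submodule that actually exists in the model, so the Identity attribute has to be called idt (or the config adjusted to whatever name you chose).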

hobbitlzy commented 1 year ago

Thanks, that solves the problem. :)

hobbitlzy commented 1 year ago

Hi @J-shang. One more comment: I found that if I pack the outputs into a tuple, output = (layer_outputs, second_output), the distiller seems to pick output[0] as the distillation target, which I checked in this method. This happens to work for me, since I only need layer_outputs for distillation. However, I have not dug into why the distiller does this, so there may be hidden pitfalls somewhere.