princeton-nlp / CoFiPruning

[ACL 2022] Structured Pruning Learns Compact and Accurate Models https://arxiv.org/abs/2204.00408
MIT License

layer-distillation: teacher layer sets selection? #34

Closed zhangzhenyu13 closed 2 years ago

zhangzhenyu13 commented 2 years ago

The original paper says: "Specifically, let T denote a set of teacher layers that we use to distill knowledge to the student model." But the code in the trainer provides only [2, 5, 8, 11], which is part of the settings in the Appendix. Do you have any suggestions for selecting the teacher layer set for distillation? Is it 4 layers at most? Which 4 layers are appropriate? How do we specify task-aware settings? That is, the student has 12 layers, so why do we select from only these 4 given layers? How about using 5, 6, or 12 layers for T? I think this is critical for reproducing the results; so far I can barely reproduce any results that match the reported scores.

xiamengzhou commented 2 years ago

In our experiments, we always use layers [2, 5, 8, 11] for distillation, following TinyBERT's practice. This is a manual setup, and the selection of these layers could depend on the final sparsity and the expected number of layers remaining in the student model. We noticed in preliminary experiments that models at 95% sparsity rarely end up with a final structure of fewer than four layers, suggesting that four layers might be optimal in the 95% sparsity regime. For lower sparsities, increasing the number of teacher layers to distill from might help, but we didn't run the experiments to verify.
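
For illustration only, the fixed set [2, 5, 8, 11] can be read as "every third layer, 0-indexed" of a 12-layer teacher. A minimal sketch of generating such an evenly spaced set for other sizes follows; the helper name `uniform_teacher_layers` is hypothetical and not part of the repo, and it assumes the number of layers is divisible by the number of picks.

```python
def uniform_teacher_layers(num_layers=12, num_picks=4):
    """Return evenly spaced 0-based layer indices, ending at the last layer.

    Hypothetical helper (not from the CoFiPruning repo).
    """
    if num_layers % num_picks != 0:
        raise ValueError("num_layers must be divisible by num_picks")
    step = num_layers // num_picks
    # Take the last layer of each block of `step` consecutive layers.
    return [step * (i + 1) - 1 for i in range(num_picks)]


print(uniform_teacher_layers())        # [2, 5, 8, 11]
print(uniform_teacher_layers(12, 6))   # [1, 3, 5, 7, 9, 11]
```

Whether a larger set actually helps at lower sparsities is untested, as noted above.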

Could you share more details on your experimental setup so I can help with result reproduction?

zhangzhenyu13 commented 2 years ago

Thanks. The fixed teacher layer set might be suboptimal for different datasets. I found that selecting the teacher layer set with a random strategy at every training step consistently leads to satisfying results. The idea is inspired by: [2109.10164] RAIL-KD: RAndom Intermediate Layer Mapping for Knowledge Distillation (arxiv.org).
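
As a rough standalone sketch of the windowed random strategy (what the trainer code in this thread calls version 6): split the teacher layers into equal windows and draw one layer from each. The function name and its parameters are hypothetical, and it assumes the number of layers is divisible by the number of picks.

```python
import random


def sample_teacher_layers(num_layers=12, num_picks=4, min_first=2, rng=random):
    """Draw one 0-based teacher layer index from each consecutive window.

    Hypothetical helper; `rng` can be the `random` module or a
    `random.Random` instance for reproducible draws.
    """
    if num_layers % num_picks != 0:
        raise ValueError("num_layers must be divisible by num_picks")
    window = num_layers // num_picks
    picks = [
        rng.choice(range(start, start + window))
        for start in range(0, num_layers, window)
    ]
    # Clamp the first pick so it is never below `min_first`.
    picks[0] = max(min_first, picks[0])
    return picks


rng = random.Random(0)
print(sample_teacher_layers(rng=rng))
```

With these defaults the first pick is always 2, because the first window is [0, 1, 2] and the clamp raises anything lower; only the other three picks vary between steps.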

I have modified the trainer method, and I always use version 6. The following is the adapted code in the trainer.calculate_layer_distillation_loss method:

```python
# Excerpt from the trainer; assumes `random`, `sys`, and `torch` are imported
# and that a module-level `logger` exists.
def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
    mse_loss = torch.nn.MSELoss(reduction="mean")
    if self.additional_args.do_layer_distill:  #! only do layer distill
        mlp_z = None
        head_layer_z = None
        layer_loss = 0  # initialized up front so every version can accumulate into it

        logger.info(f"zs={zs}")

        if "mlp_z" in zs:
            mlp_z = zs["mlp_z"].detach().cpu()
        if "head_layer_z" in zs:
            head_layer_z = zs["head_layer_z"].detach().cpu()

        teacher_layer_output = teacher_outputs[2][1:]  #! hidden states, with a length of 12. Each has a shape of [32, 65, 768]
        student_layer_output = student_outputs[2][1:]

        # distilling existing layers
        if self.additional_args.layer_distill_version == 2:
            for layer_num, (t_layer_o, s_layer_o) in enumerate(zip(teacher_layer_output, student_layer_output)):
                s_layer_o = self.model.layer_transformation(s_layer_o)
                l = mse_loss(t_layer_o, s_layer_o)
                if mlp_z is None or mlp_z[layer_num] > 0:
                    layer_loss += l

        # distilling layers with a minimal distance
        elif self.additional_args.layer_distill_version > 2:
            l = []
            if self.additional_args.layer_distill_version > 4:
                specified_teacher_layers = list(range(12))
                if self.additional_args.layer_distill_version == 5:
                    # version 5: any 4 teacher layers, kept in ascending order
                    specified_teacher_layers = sorted(random.sample(specified_teacher_layers, 4))
                elif self.additional_args.layer_distill_version == 6:
                    # version 6: one random teacher layer from each window of 3 layers
                    result_layers_T = []
                    skip_window = len(specified_teacher_layers) // 4
                    for i in range(0, len(specified_teacher_layers), skip_window):
                        result_layers_T.append(random.sample(specified_teacher_layers[i:i + skip_window], 1)[0])
                    specified_teacher_layers = result_layers_T
                # the first teacher layer is never below index 2
                specified_teacher_layers[0] = max(2, specified_teacher_layers[0])
            else:
                specified_teacher_layers = [2, 5, 8, 11]
            # logger.info(f"sampled teacher layers: {specified_teacher_layers}")
            transformed_s_layer_o = [self.model.layer_transformation(
                s_layer_o) for s_layer_o in student_layer_output]
            specified_teacher_layer_reps = [
                teacher_layer_output[i] for i in specified_teacher_layers]  #! teacher: 4x[32,113,768]

            device = transformed_s_layer_o[0].device
            for t_layer_o in specified_teacher_layer_reps:
                for i, s_layer_o in enumerate(transformed_s_layer_o):  #! student: 12x[32,113,768]
                    l.append(mse_loss(t_layer_o, s_layer_o))
            layerwiseloss = torch.stack(l).reshape(
                len(specified_teacher_layer_reps), len(student_layer_output))  #! [4,12]

            existing_layers = None
            if head_layer_z is not None:
                existing_layers = head_layer_z != 0
                existing_layers = existing_layers.to(layerwiseloss.device)

            #! no ordering restriction specified
            if self.additional_args.layer_distill_version == 3:
                alignment = torch.argmin(layerwiseloss, dim=1)
            #! added the ordering restriction -> to choose the min loss in 4 student layers
            elif self.additional_args.layer_distill_version in (4, 5, 6):
                last_aligned_layer = 12
                alignment = []
                for search_index in range(len(specified_teacher_layers) - 1, -1, -1):
                    indexes = layerwiseloss[search_index].sort()[1]
                    if existing_layers is not None:
                        align = indexes[(
                            indexes < last_aligned_layer) & existing_layers]
                    else:
                        align = indexes[indexes < last_aligned_layer]
                    if len(align) > 0:
                        align = align[0]
                    else:
                        align = last_aligned_layer
                    alignment.append(align)
                    last_aligned_layer = align
                alignment.reverse()
                alignment = torch.tensor(alignment).to(device)
            else:
                logger.info(
                    f"{self.additional_args.layer_distill_version} version is not specified.")
                sys.exit()

            layerwise = torch.arange(len(specified_teacher_layers)).to(device)
            layer_loss += layerwiseloss[layerwise, alignment].sum()  #! layerwise: teacher (specified layers) / alignment: student (min loss layers) / layerwiseloss: [4,12]
            if self.global_step % 100 == 0:
                logger.info(f"v{self.additional_args.layer_distill_version} Global step: {self.global_step}, Alignment: " + str(alignment))
        return layer_loss
    else:
        return None
```

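
To make the ordering restriction in the code above easier to follow, here is a simplified, torch-free sketch of the same greedy idea on plain Python lists. It walks the teacher layers from last to first and gives each one the lowest-loss surviving student layer that sits strictly below the previously aligned one. The function name `ordered_alignment` and the list-based inputs are assumptions made for this illustration, not the trainer's API.

```python
def ordered_alignment(loss, existing=None):
    """Greedy order-preserving alignment of teacher layers to student layers.

    loss[t][s] is the distillation loss between selected teacher layer t and
    student layer s; existing[s] is False for pruned student layers.
    Hypothetical helper written for illustration.
    """
    num_student = len(loss[0])
    if existing is None:
        existing = [True] * num_student
    if not any(existing):
        raise ValueError("no student layer left to align with")

    last_aligned = num_student
    alignment = []
    for t in range(len(loss) - 1, -1, -1):
        # Only surviving student layers strictly below the last aligned one.
        candidates = [s for s in range(last_aligned) if existing[s]]
        if candidates:
            best = min(candidates, key=lambda s: loss[t][s])
        else:
            # No layer left below: reuse the previously aligned layer.
            best = last_aligned
        alignment.append(best)
        last_aligned = best
    alignment.reverse()
    return alignment


loss = [
    [0.1, 0.5, 0.9, 0.9, 0.9],
    [0.9, 0.2, 0.1, 0.9, 0.9],
    [0.9, 0.9, 0.3, 0.8, 0.1],
]
print(ordered_alignment(loss))                                    # [0, 2, 4]
print(ordered_alignment(loss, [True, True, False, True, True]))   # [0, 1, 4]
```

In the second call student layer 2 is treated as pruned, so the middle teacher layer falls back to student layer 1, its next-best candidate below layer 4.
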
xiamengzhou commented 2 years ago

Hi,

Yes, I agree that the fixed set of teacher layers might be suboptimal, and it's super cool that the random layer selection version works well! I am very curious about how the final results differ between versions 3 and 4 and versions 5 and 6. Also, feel free to open a pull request to the repo; I am happy to acknowledge the contribution on the main page.

xiamengzhou commented 2 years ago

Hi, I am closing this issue; feel free to reopen it if necessary. Thanks again for your contribution!