In our experiments, we always use layers [2, 5, 8, 11] for distillation, following TinyBERT's practice. This is a manual setup, and the selection of these layers could depend on the final sparsity and the expected number of layers remaining in the student model. We noticed in preliminary experiments that models at 95% sparsity rarely end up with fewer than four layers, suggesting that four teacher layers might be optimal in the 95% sparsity regime. For lower sparsities, distilling from more teacher layers might help, but we didn't run experiments to verify this.
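For what it's worth, here is a minimal sketch (not code from the repo; the helper name `pick_teacher_layers` is hypothetical) of the TinyBERT-style uniform mapping described above, which recovers `[2, 5, 8, 11]` for a 12-layer teacher and an expected 4-layer student:

```python
# Hypothetical helper (not from the repo): evenly spaced teacher layers for a
# target student depth, assuming a 12-layer BERT-base teacher (0-indexed layers).
def pick_teacher_layers(num_teacher_layers: int = 12, num_student_layers: int = 4) -> list:
    step = num_teacher_layers // num_student_layers
    return [step * (i + 1) - 1 for i in range(num_student_layers)]

print(pick_teacher_layers(12, 4))  # [2, 5, 8, 11]
print(pick_teacher_layers(12, 6))  # [1, 3, 5, 7, 9, 11]
```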
Could you share more details on your experimental setup so I can help with result reproduction?
Thanks. The fixed teacher layer sets might be suboptimal for different datasets. I found that selecting the teacher layer set with a random strategy at every training step consistently leads to satisfying results. The idea is inspired by [2109.10164] RAIL-KD: RAndom Intermediate Layer Mapping for Knowledge Distillation (arxiv.org).
I modified the trainer so that I always use version 6. The following is the adapted code for the `trainer.calculate_layer_distillation_loss` method:
```python
def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
    # Assumes module-level `import random`, `import sys`, `import torch`, and a configured `logger`.
    mse_loss = torch.nn.MSELoss(reduction="mean")
    if self.additional_args.do_layer_distill:  #! only do layer distill
        mlp_z = None
        head_layer_z = None
        if "mlp_z" in zs:
            mlp_z = zs["mlp_z"].detach().cpu()
        if "head_layer_z" in zs:
            head_layer_z = zs["head_layer_z"].detach().cpu()

        teacher_layer_output = teacher_outputs[2][1:]  #! hidden states, with a length of 12. Each has a shape of [32, 65, 768]
        student_layer_output = student_outputs[2][1:]

        # distilling existing layers
        if self.additional_args.layer_distill_version == 2:
            layer_loss = 0
            for layer_num, (t_layer_o, s_layer_o) in enumerate(zip(teacher_layer_output, student_layer_output)):
                s_layer_o = self.model.layer_transformation(s_layer_o)
                l = mse_loss(t_layer_o, s_layer_o)
                if mlp_z is None or mlp_z[layer_num] > 0:
                    layer_loss += l

        # distilling layers with a minimal distance
        elif self.additional_args.layer_distill_version > 2:
            l = []
            if self.additional_args.layer_distill_version > 4:
                specified_teacher_layers = [i for i in range(12)]
                if self.additional_args.layer_distill_version == 5:
                    # v5: uniformly sample 4 of the 12 teacher layers
                    specified_teacher_layers = sorted(random.sample(specified_teacher_layers, 4))
                elif self.additional_args.layer_distill_version == 6:
                    # v6: split the 12 teacher layers into 4 windows and sample one layer per window
                    result_layers_T = []
                    skip_window = len(specified_teacher_layers) // 4
                    for i in range(0, len(specified_teacher_layers), skip_window):
                        result_layers_T.append(random.sample(specified_teacher_layers[i:i + skip_window], 1)[0])
                    specified_teacher_layers = result_layers_T
                specified_teacher_layers[0] = max(2, specified_teacher_layers[0])
            else:
                specified_teacher_layers = [2, 5, 8, 11]
            # logger.info(f"sampled teacher layers: {specified_teacher_layers}")
            transformed_s_layer_o = [self.model.layer_transformation(s_layer_o)
                                     for s_layer_o in student_layer_output]
            specified_teacher_layer_reps = [teacher_layer_output[i]
                                            for i in specified_teacher_layers]  #! teacher: 4x[32,113,768]

            device = transformed_s_layer_o[0].device
            for t_layer_o in specified_teacher_layer_reps:
                for i, s_layer_o in enumerate(transformed_s_layer_o):  #! student: 12x[32,113,768]
                    l.append(mse_loss(t_layer_o, s_layer_o))
            layerwiseloss = torch.stack(l).reshape(
                len(specified_teacher_layer_reps), len(student_layer_output))  #! [4,12]

            existing_layers = None
            if head_layer_z is not None:
                existing_layers = head_layer_z != 0
                existing_layers = existing_layers.to(layerwiseloss.device)

            layer_loss = 0
            #! no ordering restriction
            if self.additional_args.layer_distill_version == 3:
                alignment = torch.argmin(layerwiseloss, dim=1)
            #! with the ordering restriction -> choose the min-loss student layer for each specified teacher layer
            elif self.additional_args.layer_distill_version in (4, 5, 6):
                last_aligned_layer = 12
                alignment = []
                for search_index in range(len(specified_teacher_layers) - 1, -1, -1):
                    indexes = layerwiseloss[search_index].sort()[1]
                    if existing_layers is not None:
                        align = indexes[(indexes < last_aligned_layer) & existing_layers]
                    else:
                        align = indexes[indexes < last_aligned_layer]
                    if len(align) > 0:
                        align = align[0]
                    else:
                        align = last_aligned_layer
                    alignment.append(align)
                    last_aligned_layer = align
                alignment.reverse()
                alignment = torch.tensor(alignment).to(device)
            else:
                logger.info(
                    f"{self.additional_args.layer_distill_version} version is not specified.")
                sys.exit()

            layerwise = torch.arange(len(specified_teacher_layers)).to(device)
            layer_loss += layerwiseloss[layerwise, alignment].sum()  #! layerwise: teacher (specified layers) / alignment: student (min-loss layers) / layerwiseloss: [4,12]
            if self.global_step % 100 == 0:
                logger.info(f"v{self.additional_args.layer_distill_version} Global step: {self.global_step}, Alignment: " + str(alignment))
        return layer_loss
    else:
        return None
```
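To illustrate the version-6 strategy in isolation (a standalone sketch; the function name `sample_teacher_layers_v6` and its defaults are just for this example), the windowed sampling picks one teacher layer per block of three and clamps the first pick:

```python
import random

# Standalone sketch of the version-6 selection used above: split the 12 teacher
# layers into 4 windows of 3, sample one layer from each window, then keep the
# first pick at layer 2 or higher.
def sample_teacher_layers_v6(num_teacher_layers=12, num_windows=4):
    layers = list(range(num_teacher_layers))
    window = num_teacher_layers // num_windows
    picks = [random.choice(layers[i:i + window]) for i in range(0, num_teacher_layers, window)]
    picks[0] = max(2, picks[0])
    return picks

print(sample_teacher_layers_v6())  # e.g. [2, 4, 7, 11] -- re-sampled at every training step
```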
Hi,
Yes, I agree that the fixed set of teacher layers might be suboptimal, and it's super cool that the random layer selection version works well! I am very curious about how the final results differ between versions 3/4 and versions 5/6. Also, feel free to open a pull request to the repo, and I would be happy to acknowledge the contribution on the main page.
Hi, I am closing this issue; feel free to reopen it if necessary. Thanks again for your contribution!
The original paper mentions:
"Specifically, let T denote a set of teacher layers that we use to distill knowledge to the student model." However, the code in the trainer only provides `[2, 5, 8, 11]`, which is one of the settings in the Appendix. Do you have any suggestions for choosing such teacher layer sets for distillation? Is four layers the maximum, and which four layers are proper? How do we specify task-aware settings? For instance, the student has 12 layers, so why do we only distill from the given 4 teacher layers rather than, say, 5, 6, or all 12 layers for T? I think this is critical for reproducing the results; right now I can barely reproduce any results that match the reported scores.