princeton-nlp / CoFiPruning

[ACL 2022] Structured Pruning Learns Compact and Accurate Models https://arxiv.org/abs/2204.00408
MIT License

Fatal Logic Error found in trainer.py #10

Closed zhangzhenyu13 closed 2 years ago

zhangzhenyu13 commented 2 years ago

In the file https://github.com/princeton-nlp/CoFiPruning/blob/main/trainer/trainer.py, line 279 contains the following statement:

 if self.start_prune:
    zs = self.l0_module.forward(training=True)
    self.fill_inputs_with_zs(zs, inputs)

Only when this runs do the params in self.l0_optimizer get gradients, and it runs only when the condition below is satisfied (line 268):

  if self.prepruning_finetune_steps > 0 and self.global_step == self.prepruning_finetune_steps:
      self.start_prune = True

However, line 301 directly updates the params without checking whether the grads are ready:

  if self.l0_module is not None and self.l0_optimizer is not None:
      self.l0_optimizer.step()
      self.lagrangian_optimizer.step()

Therefore AdamW raises an error in its step method, where beta1/beta2 are referenced before they are assigned: since the grads of these params are all None, the AdamW implementation skips the code that reads these hyper-parameters from the param group dict.
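For example, the two step calls could share the same guard (just a sketch based on the snippets above, not the exact change in my pull request):

    if self.l0_module is not None and self.l0_optimizer is not None:
        # Only step the l0/lagrangian optimizers once pruning has started,
        # i.e. once zs have been filled in and their params have gradients.
        if self.start_prune:
            self.l0_optimizer.step()
            self.lagrangian_optimizer.step()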

xiamengzhou commented 2 years ago

Hi,

If the gradient is None, self.l0_optimizer.step() has no effect on the set of parameters passed to the optimizer. It starts taking effect (updating the parameters) once the gradients are no longer None after pruning begins.
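For example (a quick standalone sketch; the exact behavior depends on the torch release, see below):

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.AdamW([p], lr=1e-2)

    # p.grad is None, so on recent torch releases step() skips p
    # and leaves it unchanged.
    opt.step()
    print(p.data)   # still all zeros
    print(p.grad)   # still None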

zhangzhenyu13 commented 2 years ago

Okay, the torch version I used is 1.8.2, where the AdamW used in the trainer is defined in the following file (the step method starts at line 55): https://github.com/pytorch/pytorch/blob/v1.8.2/torch/optim/adamw.py. It continues the loop at line 77:

            for p in group['params']:
                if p.grad is None:
                    continue

But beta1, beta2 = group['betas'] appears only at line 104, inside the loop body that the continue above skips. The code then references the betas at lines 117 and 118:

    F.adamw(params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            amsgrad,
            beta1,
            beta2,
            group['lr'],
            group['weight_decay'],
            group['eps'])

Thus the undefined-variable error happens in the call above.
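A minimal standalone reproducer of what I see under torch 1.8.x (hypothetical snippet, not from the repo):

    import torch

    # A parameter that never receives a gradient, like the l0/lagrangian
    # params before self.start_prune becomes True.
    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.AdamW([p], lr=1e-2)

    # On torch 1.8.x this raises
    #   UnboundLocalError: local variable 'beta1' referenced before assignment
    # because every param is skipped by the `continue` above, so
    # beta1, beta2 = group['betas'] never executes before F.adamw is called.
    opt.step()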

For compatibility and robustness, I think it is necessary to guard the step call with the flag I proposed in the pull request. Because the flexible pruning-and-distillation method from your research will be widely used, torch users on all kinds of versions will refer to this code, so we should account for such LTS versions (e.g. torch 1.8.x).

However, the latest torch version has a different code flow, so the bug does not happen there. The AdamW it uses is defined in https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py; the step method is defined at line 108, and the statement beta1, beta2 = group['betas'] at line 130 comes before the 'skip if no grad' continue (line 132). Thanks very much for your replies.

xiamengzhou commented 2 years ago

Thanks for identifying this! I merged the fix into the main branch.