sonakshiigarg closed this 3 weeks ago.
Hello! I am having the same error. Any update on this?
What is the definition of the `Trainer` class? I need to see how `self.model` is initialized.
Also, could you provide a Colab link? The current code is difficult to read and is missing a lot of environment information, which makes it very hard for me to reproduce the error.
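A quick way to gather the environment details requested here, sketched with the standard library only. The package list is an assumption about what is relevant to this issue; adjust it as needed.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("torch", "transformers", "opacus", "datasets")):
    """Return Python, OS, and installed-package versions as plain text."""
    lines = [f"python: {sys.version.split()[0]}", f"platform: {platform.platform()}"]
    for name in packages:
        try:
            lines.append(f"{name}: {version(name)}")
        except PackageNotFoundError:
            # Report missing packages instead of failing.
            lines.append(f"{name}: not installed")
    return "\n".join(lines)


print(environment_report())
```

Pasting this output into the issue gives the exact library versions, which matters here because the `PrivacyEngine` API changed between Opacus releases.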
I am using Opacus's `PrivacyEngine` for knowledge distillation with Hugging Face Transformers, but I get the error `Trying to add hooks twice to the same model`. I attach the privacy engine only once, to `student_model`, yet it seems to be attached twice. I need help fixing the code.
The code is:
```python
import torch
from opacus import PrivacyEngine
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, epsilon=1.0, delta=1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon
        self.delta = delta
        self.alpha = alpha
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.to(self.model.device)
        self.teacher.eval()
        self.model.train()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.learning_rate, momentum=0.5)
        self.model.optimizer = optimizer
        if not hasattr(optimizer, "privacy_engine"):
            privacy_engine = PrivacyEngine()
            privacy_engine.make_private(
                module=self.model,
                data_loader=tokenized_datasets["train"],
                optimizer=optimizer,
                noise_multiplier=1.0,
                max_grad_norm=1.0,
            )
            privacy_engine.attach(self.optimizer)
            optimizer.privacy_engine = privacy_engine
            privacy_engine.detach()


# teacher_id, student_id, num_labels, id2label, label2id, training_args,
# tokenized_datasets, data_collator, tokenizer and compute_metrics are
# defined elsewhere in the original script.
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
```