airaria / TextBrewer

A PyTorch-based knowledge distillation toolkit for natural language processing
http://textbrewer.hfl-rc.com
Apache License 2.0

Dimension mismatch when distilling roberta-wwm-ext #34

Closed junrong1 closed 3 years ago

junrong1 commented 3 years ago

import json
import torch
import textbrewer
from textbrewer import TrainingConfig, DistillationConfig, GeneralDistiller
from transformers import BertConfig, BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, RandomSampler

# json_2_csv, load_examples and simple_adaptor are project-specific helpers (not shown here)
with open('config.json') as f:
    config = json.load(fp=f)

num_epochs=2

bert_l12 = BertConfig.from_json_file('./bert_config/bert_l12.json')
bert_l4 = BertConfig.from_json_file('./bert_config/bert_l4.json')
teacher_model = BertModel.from_pretrained('hfl/chinese-roberta-wwm-ext').cpu()
student_model = BertModel(bert_l4).cpu()
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)

print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)

train_config = TrainingConfig(device=torch.device('cpu'))

distill_config = DistillationConfig(
    intermediate_matches=
        [{"layer_T": 0, "layer_S": 0, "feature": "hidden", "loss": "hidden_mse", "weight": 1,
          "proj": ["linear", 312, 768]},
         {"layer_T": 3, "layer_S": 1, "feature": "hidden", "loss": "hidden_mse", "weight": 1,
          "proj": ["linear", 312, 768]},
         {"layer_T": 6, "layer_S": 2, "feature": "hidden", "loss": "hidden_mse", "weight": 1,
          "proj": ["linear", 312, 768]},
         {"layer_T": 9, "layer_S": 3, "feature": "hidden", "loss": "hidden_mse", "weight": 1,
          "proj": ["linear", 312, 768]},
         {"layer_T": 12, "layer_S": 4, "feature": "hidden", "loss": "hidden_mse", "weight": 1,
          "proj": ["linear", 312, 768]}])

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
data_dir = json_2_csv(config['json_file_path'], config['csv_file_output_dir']) + '.csv'
train_dataset = load_examples(tokenizer,
                              data_dir=data_dir,
                              max_seq_len=config['max_seq_len'],
                              set_type='train',
                              config=config)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=config['hyper_parameters']['batch_size'])

optimizer = AdamW(student_model.parameters(), lr=1e-4)
scheduler = None
scheduler_class = get_linear_schedule_with_warmup
num_training_steps = len(train_dataloader) * num_epochs
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

distiller.train(optimizer, train_dataloader, num_epochs=num_epochs,
                scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                callback=None)

Traceback (most recent call last):
  File "/Users/ray.yao/Desktop/daas-text-align/model/know_dis.py", line 96, in <module>
    distiller.train(optimizer, train_dataloader, num_epochs=num_epochs, scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_basic.py", line 283, in train
    self.train_with_num_epochs(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_basic.py", line 212, in train_with_num_epochs
    total_loss, losses_dict = self.train_on_batch(batch,args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_general.py", line 74, in train_on_batch
    (teacher_batch, results_T), (student_batch, results_S) = get_outputs_from_batch(batch, self.t_config.device, self.model_T, self.model_S, args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_utils.py", line 274, in get_outputs_from_batch
    results_T = auto_forward(model_T,batch,args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_utils.py", line 294, in auto_forward
    results = model(*batch, **args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/transformers/modeling_bert.py", line 728, in forward
    embedding_output = self.embeddings(
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/transformers/modeling_bert.py", line 177, in forward
    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
RuntimeError: The size of tensor a (30) must match the size of tensor b (512) at non-singleton dimension 1

Process finished with exit code 1

inputs_embeds.size() = [512, 30, 768]   # batch, seq_len, hid_dim
position_embeddings.size() = [512, 768]
token_type_embeddings.size() = [512, 30, 768]
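
Aside: the snippet above passes a simple_adaptor that is not shown in the issue. For the hidden_mse intermediate matches in distill_config to have anything to match, the adaptor must expose the models' hidden states to TextBrewer. A rough sketch of what such an adaptor could look like follows (the indexing assumes both model configs set output_hidden_states=True; this is an illustration, not the reporter's actual code):

def simple_adaptor(batch, model_outputs):
    # Illustrative sketch only. TextBrewer calls adaptor(batch, model_outputs)
    # and reads the returned dict; the 'hidden' entry feeds the hidden_mse
    # intermediate matches.
    # With output_hidden_states=True, BertModel in transformers 2.x returns
    # (last_hidden_state, pooler_output, all_hidden_states); index 2 is the
    # tuple of per-layer hidden states (embedding output plus one per layer).
    # A real adaptor would usually also return 'logits' and/or 'losses' for
    # the soft/hard label losses.
    return {'hidden': model_outputs[2]}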

airaria commented 3 years ago

This is probably because the order of the elements in the batch returned by your dataloader does not match the order of the arguments that BertModel accepts. In Transformers 2.9, the argument order of BertModel.forward is:

 def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    )

Does the order of the elements in the batch returned by your dataloader match this order?

I suggest writing a custom dataset that returns a dict instead of a tuple, with each key matching the corresponding parameter name in forward; that makes this kind of misalignment much harder to run into. For example:


from torch.utils.data import Dataset

class DictDataset(Dataset):
    def __init__(self, all_input_ids, all_input_mask, all_token_type_ids, all_labels):
        super(DictDataset, self).__init__()
        self.all_input_ids = all_input_ids
        self.all_input_mask = all_input_mask
        self.all_token_type_ids = all_token_type_ids
        self.all_labels = all_labels

    def __getitem__(self, index):
        input_ids = self.all_input_ids[index]
        input_mask = self.all_input_mask[index]
        token_type_ids = self.all_token_type_ids[index]
        labels = self.all_labels[index]
        return {'input_ids': input_ids,
                'attention_mask': input_mask,
                'token_type_ids': token_type_ids,
                'labels': labels}

    def __len__(self):
        return len(self.all_labels)
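
A minimal usage sketch (the texts, labels, and sizes below are illustrative, not from the thread): PyTorch's default collate_fn stacks a dict of tensors key by key, so a DataLoader over such a dataset yields batches whose keys already line up with the parameter names of BertModel.forward.

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
texts = ["样例句子一", "样例句子二"]   # placeholder texts
labels = torch.tensor([0, 1])          # placeholder labels
# batch_encode_plus is the transformers 2.x batch-tokenization API
enc = tokenizer.batch_encode_plus(texts, max_length=32,
                                  pad_to_max_length=True, return_tensors='pt')
dataset = DictDataset(enc['input_ids'], enc['attention_mask'],
                      enc['token_type_ids'], labels)
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
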
junrong1 commented 3 years ago

Switching to this kind of dataset fixed the dimension mismatch, but now a new problem comes up: TypeError: forward() got an unexpected keyword argument 'labels'

Traceback (most recent call last):
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_utils.py", line 265, in get_outputs_from_batch
    results_T = auto_forward(model_T,batch,args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/textbrewer/distiller_utils.py", line 287, in auto_forward
    results = model(**batch, **args)
  File "/Users/ray.yao/opt/anaconda3/envs/text-align/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'labels'
python-BaseException

That is the full error message.

The transformers version I am currently using is 2.9.0.

airaria commented 3 years ago

DictDataset is only an example; adjust the keys of the dict it returns to match the model you are actually using.
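
For this thread's setup, one way to act on that advice (an illustration, not the reporter's confirmed fix): BertModel.forward in transformers 2.9 has no labels argument, so either drop the 'labels' key from the dict the dataset returns, or distill head models such as BertForSequenceClassification, whose forward does accept labels.

class BertInputsOnlyDataset(DictDataset):
    # Hypothetical subclass name; assumed adjustment, not the thread's fix.
    # When teacher and student are bare BertModel instances, return only the
    # keys that BertModel.forward actually accepts.
    def __getitem__(self, index):
        return {'input_ids': self.all_input_ids[index],
                'attention_mask': self.all_input_mask[index],
                'token_type_ids': self.all_token_type_ids[index]}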

DeqianBai commented 2 years ago

How exactly did you change the dataset? I've run into the same problem. Thanks!