
Tensor size mismatch when fine-tuning the clone detection task for the GraphCodeBERT model #73

Closed · MichaelFu1998-create closed this issue 3 years ago

MichaelFu1998-create commented 3 years ago

Hi there, I followed the code provided here to fine-tune GraphCodeBERT on the clone detection task, but I got a tensor size mismatch error.

I am using transformers == 4.10.2. Here is the complete error traceback:

Traceback (most recent call last):
  File "/HDD18TB/data/michael/vul_project/CodeBERT-master/GraphCodeBERT/clonedetection/run.py", line 624, in <module>
    main()
  File "/HDD18TB/data/michael/vul_project/CodeBERT-master/GraphCodeBERT/clonedetection/run.py", line 602, in main
    train(args, train_dataset, model, tokenizer)
  File "/HDD18TB/data/michael/vul_project/CodeBERT-master/GraphCodeBERT/clonedetection/run.py", line 372, in train
    loss,logits = model(inputs_ids_1,position_idx_1,attn_mask_1,inputs_ids_2,position_idx_2,attn_mask_2,labels)
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/michael/.local/lib/python3.9/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/HDD18TB/data/michael/vul_project/CodeBERT-master/GraphCodeBERT/clonedetection/model.py", line 53, in forward
    outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,attention_mask=attn_mask,position_ids=position_idx)[0]
  File "/home/michael/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/michael/.local/lib/python3.9/site-packages/transformers/models/roberta/modeling_roberta.py", line 809, in forward
    buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
RuntimeError: The expanded size of the tensor (640) must match the existing size (514) at non-singleton dimension 1. Target sizes: [16, 640]. Tensor sizes: [1, 514]

The error is triggered by this code block in the transformers library (transformers/models/roberta/modeling_roberta.py), when the buffered tensor is expanded via buffered_token_type_ids.expand(batch_size, seq_length):

if token_type_ids is None:
    if hasattr(self.embeddings, "token_type_ids"):
        buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
        buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
        token_type_ids = buffered_token_type_ids_expanded
    else:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
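
For reference, the size mismatch itself is easy to reproduce in isolation. Below is a minimal sketch assuming only PyTorch; the 514 is the length of roberta-base's registered token_type_ids buffer (sized to max_position_embeddings), and 640 is the clone-detection input length (code tokens plus data-flow nodes under the default settings):

import torch

# roberta-base registers a (1, 514) token_type_ids buffer, sized to the
# model's max_position_embeddings rather than to the actual input length.
buffered_token_type_ids = torch.zeros(1, 514, dtype=torch.long)

# GraphCodeBERT's clone-detection inputs are 640 positions long, so the
# slice [:, :seq_length] still has size 514 and expand() cannot broadcast
# a non-singleton dimension.
batch_size, seq_length = 16, 640
try:
    buffered_token_type_ids[:, :seq_length].expand(batch_size, seq_length)
except RuntimeError as e:
    print(e)  # The expanded size of the tensor (640) must match the existing size (514) ...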

If we skip the expansion and fall back to the all-zeros default, like this:

if token_type_ids is None:
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

the fine-tuning runs, but I am not sure whether this change affects the results.

Thank you very much for the help. Appreciate it :)

guoday commented 3 years ago

Your change is correct. The error comes from an update to transformers.

There are two other ways to solve the error:

  1. Use an older version of transformers.

  2. Change https://github.com/microsoft/CodeBERT/blob/c03cd8aa458c6c1935226db4d37fe7cd2c5b9ee1/GraphCodeBERT/clonedetection/model.py#L53 to outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask, position_ids=position_idx, token_type_ids=position_idx.eq(-1))[0] (see the sketch below).
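
To illustrate why option 2 works: passing token_type_ids explicitly means the "token_type_ids is None" branch in modeling_roberta.py is never taken, so the (1, 514) buffer is never expanded to the 640-long input. Here is a minimal, self-contained sketch using a toy RobertaConfig with made-up small dimensions (not the actual GraphCodeBERT checkpoint), and including the .long() cast that, as discussed further down in this thread, the mask also needs:

import torch
from transformers import RobertaConfig, RobertaModel

# Toy config mirroring the one relevant dimension: max_position_embeddings=514
# is what sizes the registered token_type_ids buffer.
config = RobertaConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
                       intermediate_size=64, max_position_embeddings=514)
model = RobertaModel(config)

batch_size, seq_length = 2, 640
inputs_embeddings = torch.randn(batch_size, seq_length, config.hidden_size)
attn_mask = torch.ones(batch_size, seq_length)
position_idx = torch.randint(2, 514, (batch_size, seq_length))

# Without token_type_ids, recent transformers versions try to expand the
# (1, 514) buffer to (batch_size, 640) and fail; passing it explicitly
# skips that code path entirely.
outputs = model(inputs_embeds=inputs_embeddings,
                attention_mask=attn_mask,
                position_ids=position_idx,
                token_type_ids=position_idx.eq(-1).long())[0]
print(outputs.shape)  # torch.Size([2, 640, 32])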

MichaelFu1998-create commented 3 years ago

Thanks for the swift and precise response.

guoday commented 2 years ago

> Hi, I have the same problem. My transformers version is 4.15.0. I'm not familiar with transformers; how can I fix this?

Replace https://github.com/microsoft/CodeBERT/blob/c03cd8aa458c6c1935226db4d37fe7cd2c5b9ee1/GraphCodeBERT/clonedetection/model.py#L53 with:

outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask, position_ids=position_idx, token_type_ids=position_idx.eq(-1))[0]

1021149914 commented 2 years ago

Thanks for your help.

However, there is another problem:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABoolType instead (while checking arguments for embedding)

You need to cast 'token_type_ids' to long:

outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask, position_ids=position_idx, token_type_ids=position_idx.eq(-1).long())[0]
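
For completeness, a small sketch of why the cast is needed: nn.Embedding (which backs RoBERTa's token-type embeddings) only accepts integer indices, while Tensor.eq returns a boolean tensor. The one-row table and the 768 below are just illustrative stand-ins for roberta-base's token-type vocabulary and hidden size:

import torch
import torch.nn as nn

# A one-row embedding table, like RoBERTa's token_type_embeddings.
token_type_embeddings = nn.Embedding(1, 768)
position_idx = torch.tensor([[2, 3, 4, 0, 1]])

try:
    token_type_embeddings(position_idx.eq(-1))  # bool indices are rejected
except RuntimeError as e:
    print(e)  # e.g. "Expected tensor ... scalar types: Long, Int; but got ... Bool ..."

out = token_type_embeddings(position_idx.eq(-1).long())  # all-zero long indices
print(out.shape)  # torch.Size([1, 5, 768])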

guoday commented 2 years ago

Yeah, you are right.