taishan1994 / pytorch_bert_multi_classification

基于pytorch_bert的中文多标签分类
82 stars 14 forks source link

你读数据的时候没有补零,DataLoader读取数据的时候会报错 #1

Closed MgArcher closed 2 years ago

taishan1994 commented 2 years ago

No description provided.

具体哪里有问题呢,应该没有问题,tokenizer.encode_plus()里面有个padding参数,设置后会自动填充0到最大长度

MgArcher commented 2 years ago

preprocess.py文件74行处:

encode_dict = tokenizer.encode_plus(text=raw_text, add_special_tokens=True, max_length=max_seq_len, truncation_strategy='longest_first', padding="max_length", return_token_type_ids=True, return_attention_mask=True) token_ids = encode_dict['input_ids']

padding参数似乎没有生效,token_ids并没有补0 解决方法: 在preprocess.py文件85行处添加代码进行手动补0: while len(token_ids) < max_seq_len: token_ids.append(0) attention_masks.append(0) token_type_ids.append(0)

assert len(token_ids) == max_seq_len assert len(attention_masks) == max_seq_len assert len(token_type_ids) == max_seq_len

MgArcher commented 2 years ago

preprocess.py文件74行处:

encode_dict = tokenizer.encode_plus(text=raw_text, add_special_tokens=True, max_length=max_seq_len, truncation_strategy='longest_first', padding="max_length", return_token_type_ids=True, return_attention_mask=True) token_ids = encode_dict['input_ids']

padding参数似乎没有生效,token_ids并没有补0 解决方法: 在preprocess.py文件85行处添加代码进行手动补0: while len(token_ids) < max_seq_len: token_ids.append(0) attention_masks.append(0) token_type_ids.append(0)

assert len(token_ids) == max_seq_len assert len(attention_masks) == max_seq_len assert len(token_type_ids) == max_seq_len

taishan1994 commented 2 years ago

preprocess.py文件74行处:

encode_dict = tokenizer.encode_plus(text=raw_text,

add_special_tokens=True, max_length=max_seq_len, truncation_strategy='longest_first', padding="max_length", return_token_type_ids=True, return_attention_mask=True) token_ids = encode_dict['input_ids'] padding参数似乎没有生效,token_ids并没有补0 解决方法: 在preprocess.py文件85行处添加代码进行手动补0: while len(token_ids) < max_seq_len: token_ids.append(0) attention_masks.append(0) token_type_ids.append(0)

assert len(token_ids) == max_seq_len assert len(attention_masks) == max_seq_len assert len(token_type_ids) == max_seq_len

可能是transformers版本的问题吧,transformers==4.7.0是可行的:

encode_dict = tokenizer.encode_plus(text=raw_text,
                                        add_special_tokens=True,
                                        max_length=max_seq_len,
                                        truncation='only_first',
                                        padding="max_length",
                                        return_token_type_ids=True,
                                        return_attention_mask=True)

改了下truncation='only_first'。不过加上你那个更稳妥,将你的代码已更新到里面了。