xujunrt opened this issue 2 years ago
```python
# Map each word index to the index of that word's first sub-token, so that
# entity spans given as word positions can be projected onto the sub-token
# sequence produced by a WordPiece-style tokenizer.
word_id_2_token_id_mapping = {}
token_len = 0
if isinstance(textlist, list):
    wordlist = textlist
    textlist = " ".join(textlist)
else:
    wordlist = textlist.split(" ")
for word_idx, word in enumerate(wordlist):
    word_id_2_token_id_mapping[word_idx] = token_len
    token_len += len(tokenizer.tokenize(word))

tokens = tokenizer.tokenize(textlist)
start_ids = [0] * len(tokens)
end_ids = [0] * len(tokens)
subjects_id = []
for subject in subjects:
    label = subject[0]
    # The start label goes on the first sub-token of the start word.
    start = word_id_2_token_id_mapping[subject[1]]
    # The end label goes on the last sub-token of the (inclusive) end word:
    # the token just before the next word begins, or the last token overall.
    # (Mapping the end word to its *first* sub-token would truncate entities
    # whose final word is itself split into several sub-tokens.)
    end = word_id_2_token_id_mapping.get(subject[2] + 1, len(tokens)) - 1
    start_ids[start] = label2id[label]
    end_ids[end] = label2id[label]
    subjects_id.append((label2id[label], start, end))
```
A simple treatment like this is enough. For example, "7.7" gets split into "7", ".", "7", so the token count grows by extra sub-tokens, but with this mapping the start label still lands on the correct position.
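A minimal sketch to sanity-check the mapping, assuming the `transformers` package; `bert-base-uncased` is picked here only for illustration, and any WordPiece-style tokenizer behaves the same way:

```python
from transformers import BertTokenizer

# Assumption: bert-base-uncased is just an example checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

textlist = ["price", "rose", "to", "7.7", "percent"]

# Same mapping as in the patch above: word index -> first sub-token index.
word_id_2_token_id_mapping = {}
token_len = 0
for word_idx, word in enumerate(textlist):
    word_id_2_token_id_mapping[word_idx] = token_len
    token_len += len(tokenizer.tokenize(word))

print(tokenizer.tokenize(" ".join(textlist)))
# ['price', 'rose', 'to', '7', '.', '7', 'percent']
print(word_id_2_token_id_mapping)
# {0: 0, 1: 1, 2: 2, 3: 3, 4: 6}
```

"7.7" (word 3) becomes three sub-tokens, yet the word after it is still mapped to token index 6 rather than 4, so span labels stay aligned with the sub-token sequence.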