Reference code
class Tokenizer:
    """ Tokenizer class """

    def __init__(self, vocab, split_fn, pad_fn, maxlen):
        self._vocab = vocab
        self._split = split_fn
        self._pad = pad_fn
        self._maxlen = maxlen

    # def split(self, string: str) -> list[str]:
    def split(self, string):
        tokens = self._split(string)
        return tokens

    # def transform(self, list_of_tokens: list[str]) -> list[int]:
    def transform(self, tokens):
        indices = self._vocab.to_indices(tokens)
        pad_indices = self._pad(indices, pad_id=0, maxlen=self._maxlen) if self._pad else indices
        return pad_indices

    # def split_and_transform(self, string: str) -> list[int]:
    def split_and_transform(self, string):
        return self.transform(self.split(string))

    @property
    def vocab(self):
        return self._vocab

    def list_of_tokens_to_list_of_token_ids(self, X_token_batch):
        X_ids_batch = []
        for X_tokens in X_token_batch:
            X_ids_batch.append([self._vocab.transform_token2idx(X_token) for X_token in X_tokens])
        return X_ids_batch

    def list_of_string_to_list_of_tokens(self, X_str_batch):
        X_token_batch = [self._split(X_str) for X_str in X_str_batch]
        return X_token_batch

    def list_of_string_to_list_token_ids(self, X_str_batch):
        X_token_batch = self.list_of_string_to_list_of_tokens(X_str_batch)
        X_ids_batch = self.list_of_tokens_to_list_of_token_ids(X_token_batch)
        return X_ids_batch

    def list_of_string_to_arr_of_pad_token_ids(self, X_str_batch, add_start_end_token=False):
        X_token_batch = self.list_of_string_to_list_of_tokens(X_str_batch)
        # print("X_token_batch: ", X_token_batch)
        if add_start_end_token is True:
            return self.add_start_end_token_with_pad(X_token_batch)
        else:
            X_ids_batch = self.list_of_tokens_to_list_of_token_ids(X_token_batch)
            pad_X_ids_batch = self._pad(X_ids_batch, pad_id=self._vocab.PAD_ID, maxlen=self._maxlen)
            return pad_X_ids_batch

    def list_of_tokens_to_list_of_cls_sep_token_ids(self, X_token_batch):
        X_ids_batch = []
        for X_tokens in X_token_batch:
            X_tokens = [self._vocab.cls_token] + X_tokens + [self._vocab.sep_token]
            X_ids_batch.append([self._vocab.transform_token2idx(X_token) for X_token in X_tokens])
        return X_ids_batch

    def list_of_string_to_arr_of_cls_sep_pad_token_ids(self, X_str_batch):
        X_token_batch = self.list_of_string_to_list_of_tokens(X_str_batch)
        X_ids_batch = self.list_of_tokens_to_list_of_cls_sep_token_ids(X_token_batch)
        pad_X_ids_batch = self._pad(X_ids_batch, pad_id=self._vocab.PAD_ID, maxlen=self._maxlen)
        return pad_X_ids_batch

    def list_of_string_to_list_of_cls_sep_token_ids(self, X_str_batch):
        X_token_batch = self.list_of_string_to_list_of_tokens(X_str_batch)
        X_ids_batch = self.list_of_tokens_to_list_of_cls_sep_token_ids(X_token_batch)
        return X_ids_batch

    def add_start_end_token_with_pad(self, X_token_batch):
        dec_input_token_batch = [[self._vocab.START_TOKEN] + X_token for X_token in X_token_batch]
        dec_output_token_batch = [X_token + [self._vocab.END_TOKEN] for X_token in X_token_batch]
        dec_input_token_batch = self.list_of_tokens_to_list_of_token_ids(dec_input_token_batch)
        pad_dec_input_ids_batch = self._pad(dec_input_token_batch, pad_id=self._vocab.PAD_ID, maxlen=self._maxlen)
        dec_output_ids_batch = self.list_of_tokens_to_list_of_token_ids(dec_output_token_batch)
        pad_dec_output_ids_batch = self._pad(dec_output_ids_batch, pad_id=self._vocab.PAD_ID, maxlen=self._maxlen)
        return pad_dec_input_ids_batch, pad_dec_output_ids_batch

    def decode_token_ids(self, token_ids_batch):
        list_of_token_batch = []
        for token_ids in token_ids_batch:
            token_token = [self._vocab.transform_idx2token(token_id) for token_id in token_ids]
            # token_token = [self._vocab[token_id] for token_id in token_ids]
            list_of_token_batch.append(token_token)
        return list_of_token_batch
from collections import Counter
from threading import Thread
import json


class Vocabulary(object):
    """ Vocab class """

    def __init__(self, token_to_idx=None):
        self.token_to_idx = {}
        self.idx_to_token = {}
        self.idx = 0

        self.PAD = self.padding_token = "[PAD]"
        self.START_TOKEN = "<S>"
        self.END_TOKEN = "<T>"
        self.UNK = "[UNK]"
        self.CLS = "[CLS]"
        self.MASK = "[MASK]"
        self.SEP = "[SEP]"
        self.SEG_A = "[SEG_A]"
        self.SEG_B = "[SEG_B]"
        self.NUM = "<num>"

        self.cls_token = self.CLS
        self.sep_token = self.SEP
        self.special_tokens = [self.PAD,
                               self.START_TOKEN,
                               self.END_TOKEN,
                               self.UNK,
                               self.CLS,
                               self.MASK,
                               self.SEP,
                               self.SEG_A,
                               self.SEG_B,
                               self.NUM]
        self.init_vocab()

        if token_to_idx is not None:
            self.token_to_idx = token_to_idx
            self.idx_to_token = {v: k for k, v in token_to_idx.items()}
            self.idx = len(token_to_idx) - 1

            # if the pad token is in the token_to_idx dict, get its id
            if self.PAD in self.token_to_idx:
                self.PAD_ID = self.transform_token2idx(self.PAD)
            else:
                self.PAD_ID = 0

    def init_vocab(self):
        for special_token in self.special_tokens:
            self.add_token(special_token)
        self.PAD_ID = self.transform_token2idx(self.PAD)

    def __len__(self):
        return len(self.token_to_idx)

    def to_indices(self, tokens):
        return [self.transform_token2idx(X_token) for X_token in tokens]

    def add_token(self, token):
        if token not in self.token_to_idx:
            self.token_to_idx[token] = self.idx
            self.idx_to_token[self.idx] = token
            self.idx += 1

    def transform_token2idx(self, token, show_oov=False):
        try:
            return self.token_to_idx[token]
        except KeyError:
            if show_oov is True:
                print("key error: " + str(token))
            token = self.UNK
            return self.token_to_idx[token]

    def transform_idx2token(self, idx):
        try:
            return self.idx_to_token[idx]
        except KeyError:
            print("key error: " + str(idx))
            idx = self.token_to_idx[self.UNK]
            return self.idx_to_token[idx]

    def build_vocab(self, list_of_str, threshold=1, vocab_save_path="./data_in/token_vocab.json",
                    split_fn=None):
        """Build a token vocab"""

        def do_concurrent_tagging(start, end, text_list, counter):
            for i, text in enumerate(text_list[start:end]):
                text = text.strip()
                text = text.lower()
                try:
                    tokens_ko = split_fn(text)
                    # tokens_ko = [str(pos[0]) + '/' + str(pos[1]) for pos in tokens_ko]
                    counter.update(tokens_ko)
                    if i % 1000 == 0:
                        print("[%d/%d (total: %d)] Tokenized input text." % (
                            start + i, start + len(text_list[start:end]), len(text_list)))
                except Exception as e:  # OOM, parsing error
                    print(e)
                    continue

        counter = Counter()
        num_thread = 4
        thread_list = []
        num_list_of_str = len(list_of_str)
        for i in range(num_thread):
            thread_list.append(Thread(target=do_concurrent_tagging, args=(
                int(i * num_list_of_str / num_thread), int((i + 1) * num_list_of_str / num_thread), list_of_str,
                counter)))
        for thread in thread_list:
            thread.start()
        for thread in thread_list:
            thread.join()

        # vocab report
        print(counter.most_common(10))  # print the most common tokens
        tokens = [token for token, cnt in counter.items() if cnt >= threshold]
        for i, token in enumerate(tokens):
            self.add_token(str(token))

        print("len(self.token_to_idx): ", len(self.token_to_idx))

        with open(vocab_save_path, 'w', encoding='utf-8') as f:
            json.dump(self.token_to_idx, f, ensure_ascii=False, indent=4)

        return self.token_to_idx
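For reference (an added sketch, not from the original issue), this is how the two classes are meant to plug together. A plain whitespace split_fn and a naive list-based pad_fn stand in for the SentencePiece tokenizer and keras_pad_fn used in the actual pipeline below:

def whitespace_split_fn(string):
    # stand-in for the SentencePiece tokenizer
    return string.split()

def naive_pad_fn(indices, pad_id=0, maxlen=8):
    # pads/truncates a single sequence of ids, which is what Tokenizer.transform passes
    indices = list(indices)[:maxlen]
    return indices + [pad_id] * (maxlen - len(indices))

toy_vocab = Vocabulary()
toy_vocab.build_vocab(["첫 회를 시작으로", "매 회 2편씩 공개될 예정이다"],
                      vocab_save_path="token_vocab.json", split_fn=whitespace_split_fn)
toy_tokenizer = Tokenizer(vocab=toy_vocab, split_fn=whitespace_split_fn,
                          pad_fn=naive_pad_fn, maxlen=8)
print(toy_tokenizer.split_and_transform("첫 회를 시작으로"))  # three known ids followed by PAD (id 0)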
import keras
import numpy as np


def keras_pad_fn(token_ids_batch, maxlen, pad_id=0, padding='post', truncating='post'):
    padded_token_ids_batch = keras.preprocessing.sequence.pad_sequences(token_ids_batch,
                                                                        value=pad_id,  # vocab.transform_token2idx(PAD)
                                                                        padding=padding,
                                                                        truncating=truncating,
                                                                        maxlen=maxlen)
    return np.array(padded_token_ids_batch)
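A quick sanity check of the padding helper (an added example; it only assumes keras.preprocessing.sequence.pad_sequences is available): short sequences are post-padded with pad_id and long ones post-truncated to maxlen.

batch = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
print(keras_pad_fn(batch, maxlen=5, pad_id=0))
# with padding='post' and truncating='post' this yields:
# [[ 5  6  7  0  0]
#  [ 8  9 10 11 12]]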
import gluonnlp as nlp
from gluonnlp.data import SentencepieceTokenizer, SentencepieceDetokenizer

ptr_tokenizer = SentencepieceTokenizer("/Users/david/Downloads/kobert_news_wiki_ko_cased-ae5711deb3.spiece")
vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece("/Users/david/Downloads/kobert_news_wiki_ko_cased-ae5711deb3.spiece", padding_token='[PAD]')
token2idx = vocab_b_obj.token_to_idx
vocab = Vocabulary(token2idx)
tokenizer = Tokenizer(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=keras_pad_fn, maxlen=32)

text = "첫 회를 시작으로 13일까지 4일간 총 4회에 걸쳐 매 회 2편씩 총 8편이 공개될 예정이다."
label_text = "첫 회를 시작으로 <13일:DAT>까지 <4일간:DUR> 총 <4회:NOH>에 걸쳐 매 회 <2편:NOH>씩 총 <8편:NOH>이 공개될 예정이다."
# tokens = ['▁첫', '▁', '회를', '▁시작으로', '▁13', '일까지', '▁4', '일간', '▁총', '▁4', '회', '에', '▁걸쳐', '▁매', '▁회', '▁2', '편', '씩', '▁총', '▁8', '편', '이', '▁공개', '될', '▁예정이다', '.']
# list_of_ner_label = ['O', 'O', 'O', 'O', 'B-DAT', 'I-DAT', 'B-DUR', 'I-DUR', 'O', 'B-NOH', 'I-NOH', 'O', 'O', 'O', 'O', 'B-NOH', 'I-NOH', 'O', 'O', 'B-NOH', 'I-NOH', 'O', 'O', 'O', 'O', 'O']

# source fn
tokens = tokenizer.split(text)  # equivalent to WordPiece (BPE) tokenization => tokenizer.tokenize(text)
token_ids_with_cls_sep = tokenizer.list_of_string_to_arr_of_cls_sep_pad_token_ids([text])  # => tokenizer(text)["input_ids"]
prefix_sum_of_token_start_index = []
sum = 0
for i, token in enumerate(tokens):
    if i == 0:
        prefix_sum_of_token_start_index.append(0)
        sum += len(token) - 1
    else:
        prefix_sum_of_token_start_index.append(sum)
        sum += len(token)
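For illustration (added; it assumes the SentencePiece output matches the token list in the comment above), the loop yields character start offsets into the original text, where a word-initial token's offset points at the space represented by its leading '▁':

for token, start in list(zip(tokens, prefix_sum_of_token_start_index))[:6]:
    print(token, start)
# ▁첫 0 / ▁ 1 / 회를 2 / ▁시작으로 4 / ▁13 9 / 일까지 12
# the ('▁13', 9) entry is what the labeling loop below shifts to ('13', 10)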
# target fn
import re

regex_ner = re.compile('<(.+?):[A-Z]{3}>')  # if the NER tags are 2 characters long (e.g. LOC -> LC), change {3} to {2}
regex_filter_res = regex_ner.finditer(label_text)

list_of_ner_tag = []
list_of_ner_text = []
list_of_tuple_ner_start_end = []

count_of_match = 0
for match_item in regex_filter_res:
    ner_tag = match_item[0][-4:-1]  # <4일간:DUR> -> DUR
    ner_text = match_item[1]  # <4일간:DUR> -> 4일간
    start_index = match_item.start() - 6 * count_of_match  # subtract the previously removed '<', ':', 3-letter tag name, '>'
    end_index = match_item.end() - 6 - 6 * count_of_match
    list_of_ner_tag.append(ner_tag)
    list_of_ner_text.append(ner_text)
    list_of_tuple_ner_start_end.append((start_index, end_index))
    count_of_match += 1
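To make the offset arithmetic concrete (an added check): every previously matched tag removed six characters from the plain text ('<', ':', the three-letter tag name, '>'), so the stored (start, end) pairs index into the untagged text:

print(list_of_ner_tag)                 # ['DAT', 'DUR', 'NOH', 'NOH', 'NOH']
print(list_of_ner_text)                # ['13일', '4일간', '4회', '2편', '8편']
print(list_of_tuple_ner_start_end[0])  # (10, 13) -> text[10:13] == '13일'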
list_of_ner_label = []
entity_index = 0
is_entity_still_B = True
for tup in zip(tokens, prefix_sum_of_token_start_index):
    token, index = tup

    if '▁' in token:  # careful: '▁' (U+2581) is a different character from the ordinary underscore '_'
        index += 1  # if the token carries the leading space marker, shift the index forward by one  # ('▁13', 9) -> ('13', 10)

    if entity_index < len(list_of_tuple_ner_start_end):
        start, end = list_of_tuple_ner_start_end[entity_index]

        if end < index:  # if the current position is already past the entity span, move on to the next entity
            is_entity_still_B = True
            entity_index = entity_index + 1 if entity_index + 1 < len(list_of_tuple_ner_start_end) else entity_index
            start, end = list_of_tuple_ner_start_end[entity_index]

        if start <= index and index < end:  # e.g. <13일:DAT>까지 -> ('▁13', 10, 'B-DAT'), ('일까지', 12, 'I-DAT') gets included; to exclude such trailing pieces, the token length would also have to be taken into account
            entity_tag = list_of_ner_tag[entity_index]
            if is_entity_still_B is True:
                entity_tag = 'B-' + entity_tag
                list_of_ner_label.append(entity_tag)
                is_entity_still_B = False
            else:
                entity_tag = 'I-' + entity_tag
                list_of_ner_label.append(entity_tag)
        else:
            is_entity_still_B = True
            entity_tag = 'O'
            list_of_ner_label.append(entity_tag)
    else:
        entity_tag = 'O'
        list_of_ner_label.append(entity_tag)
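A quick way to verify the result (added) is to line the produced labels up against the expected list_of_ner_label shown in the comment near the top of this snippet:

for token, label in zip(tokens, list_of_ner_label):
    print(token, label)
# e.g. ('▁13', 'B-DAT'), ('일까지', 'I-DAT'), ('▁4', 'B-DUR'), ('일간', 'I-DUR'), ...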
raw data
split code
import random

corpus = open("/Users/david/Downloads/original_data.txt").read()
data = corpus.split("\n\n")

random.seed(7)
num_total_samples = len(data)
num_valid_samples = int(num_total_samples * 0.1)
valid_idxes = random.sample(range(num_total_samples), num_valid_samples)

train_dataset = []
valid_dataset = []
for idx, el in enumerate(data):
    if idx in valid_idxes:
        valid_dataset.append(el.split("\n"))
    else:
        train_dataset.append(el.split("\n"))

with open("/Users/david/Downloads/train2.txt", "w", encoding="utf-8") as f:
    for el in train_dataset:
        line = el[1].replace("## ", "") + "\u241E" + el[2].replace("## ", "") + "\n"
        f.writelines(line)

with open("/Users/david/Downloads/valid.txt", "w", encoding="utf-8") as f:
    for el in valid_dataset:
        line = el[1].replace("## ", "") + "\u241E" + el[2].replace("## ", "") + "\n"
        f.writelines(line)
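Each written line holds the raw sentence and its tagged counterpart joined by the U+241E record separator; a minimal added sketch of reading one record back:

with open("/Users/david/Downloads/valid.txt", encoding="utf-8") as f:
    for line in f:
        raw_sent, tagged_sent = line.rstrip("\n").split("\u241E")
        break  # just the first record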
Securing additional training data
# git clone https://github.com/kmounlp/NER.git
# https://github.com/kmounlp/NER/tree/master/%EB%A7%90%EB%AD%89%EC%B9%98%20-%20%ED%98%95%ED%83%9C%EC%86%8C_%EA%B0%9C%EC%B2%B4%EB%AA%85
import glob

fpaths = glob.glob("/Users/david/works/NER/말뭉치 - 형태소_개체명/*.txt")
with open("/Users/david/Downloads/train1.txt", "w", encoding="utf-8") as f:
    for fpath in fpaths:
        raw_lines = open(fpath, "r", encoding="utf-8").readlines()
        lines = [line.replace("\ufeff", "").replace("## ", "") for line in raw_lines if line.replace("\ufeff", "").startswith("## ")]
        assert len(lines) % 3 == 0, f"{fpath} # of line error!"
        for idx, line in enumerate(lines):
            if idx > 0 and idx % 3 == 2:
                processed_line = lines[idx - 1].strip() + "\u241E" + lines[idx].strip()
                f.writelines(processed_line + "\n")
train.txt
A performance degradation issue has come up
Overview
Write a named entity recognition (NER) tutorial.