기존과 매우 유사하지만 다른 점이 있다면 json 폴더에서 기존 preprocessing pattern으로 이미 작업한 내용이 있다면 이를 불러와서 pandas dataframe으로 변경하는 과정이 추가되어있습니다. 매번 작업을 돌리기에는 대략 50분~1시간이나 걸리기에 이러한 과정을 추가하였습니다.
def retrieve_train_BM25(
self, dataset: Union[str, Dataset], topk: Optional[int] = 1, rtt_name : Optional[str] = None
) -> Union[Tuple[List, List], pd.DataFrame]:
assert self.BM25 is not None and isinstance(dataset, Dataset)
sep_flag = 1 if self.add_special_tokens_flag == True else 0
rtt_flag = 1 if rtt_name != None else 0
json_name = f"train_retrieval_{self.pt_num}_{sep_flag}_{rtt_flag}.json"
json_path = os.path.join('./json', json_name)
if os.path.isfile(json_path):
print("Load Saved Retrieval Json Data.")
with open(json_path , "r", encoding="utf-8") as f:
json_data = json.load(f)
cqas = pd.DataFrame(json_data)
else :
total = []
print('Make Retrieval Pandas Data')
with timer("query exhaustive search"):
doc_scores, doc_indices = self.get_relevant_train_bulk_BM25(dataset, k=topk, )
for idx, example in enumerate(
tqdm(dataset, desc="BM25 retrieval: ")
):
context = " [SPLIT] ".join([self.contexts[pid] for pid in doc_indices[idx]]) if self.add_special_tokens_flag \
else " ".join([self.contexts[pid] for pid in doc_indices[idx]])
tmp = {
# Query와 해당 id를 반환합니다.
"question": example["question"],
"id": example["id"],
# Retrieve한 Passage의 id, context를 반환합니다.
"context_id": doc_indices[idx],
"context": context,
}
if "context" in example.keys() and "answers" in example.keys():
# validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
tmp["original_context"] = example["context"]
tmp["answers"] = example["answers"]
total.append(tmp)
cqas = pd.DataFrame(total)
cqas.to_json(json_path)
f = Features(
{
"answers": Sequence(
feature={
"text": Value(dtype="string", id=None),
"answer_start": Value(dtype="int32", id=None),
},
length=-1,
id=None,
),
"context": Value(dtype="string", id=None),
"id": Value(dtype="string", id=None),
"question": Value(dtype="string", id=None),
}
)
print('Make Retrieved Train Dataset')
datasets = Dataset.from_pandas(cqas, features=f)
return datasets
def get_relevant_train_bulk_BM25(
self, datasets: Dataset, k: Optional[int] = 1,
) -> Tuple[List, List]:
print("Build BM25 score, indices")
data_size = len(datasets)
queries = datasets['question']
contexts = datasets['context']
tokenized_queries= [self.tokenizer(i) for i in queries]
doc_scores = []
doc_indices = []
for i in tqdm(range(data_size)):
scores = self.BM25.get_scores(tokenized_queries[i])
context_txt = contexts[i]
sorted_score = np.sort(scores)[::-1]
sorted_id = np.argsort(scores)[::-1]
org_rank = self.contexts.index(context_txt)
selected_scores = [0]
selected_indices = [org_rank]
j = 1
size = 1
while(size < k) :
doc_id = sorted_id[j]
doc_score = sorted_score[j]
if doc_id != org_rank :
selected_scores.append(doc_score)
selected_indices.append(doc_id)
size += 1
j += 1
doc_scores.append(selected_scores)
doc_indices.append(selected_indices)
return doc_scores, doc_indices
Context를 뒤로 붙인 결과
각 문단을 [SPLIT] 라는 Special Token을 넣어주면서 구분을 하도록 하였습니다. 이를 위해서는 아래와 같은 arguments를 넣어주어야 합니다.
add_special_tokens_flag:bool = field(
default=False,
metadata={"help": "add special tokens"},
)
주 구현내용 : Train Retrieval 이라는 Argument를 넣어주면서 BM25로 가져온 Context를 뒤로 붙이는 과정입니다.
수정 부분
1. augments.py
train_retrieval이 기존을 False로 하였고 --train_retrieval이라는 인자를 학습할 때 넣어주면 동작하도록 하였습니다.
2. train.py
train_retrieval이 True면 train.py에서 아래와 같은 과정을 진행하게 됩니다.
3. retrieval_sparse_BM25.py
기존과 매우 유사하지만 다른 점이 있다면 json 폴더에서 기존 preprocessing pattern으로 이미 작업한 내용이 있다면 이를 불러와서 pandas dataframe으로 변경하는 과정이 추가되어있습니다. 매번 작업을 돌리기에는 대략 50분~1시간이나 걸리기에 이러한 과정을 추가하였습니다.
Context 사진