boostcampaitech2 / klue-level2-nlp-02

klue-level2-nlp-02 created by GitHub Classroom
0 stars 6 forks source link

Round-trip Translation으로 데이터를 늘려봅시다! #26

Open Yebin46 opened 3 years ago

Yebin46 commented 3 years ago

다른 분들께서 제가 상상도 못한 전처리를 정말 많이 시도해주셔서 (감사합니다ㅠㅠ) 저에게는 어려울 것 같아 미뤄뒀던 RTT를 작업해봤습니다..

Pororo 번역 모델 불러오기

from pororo import Pororo
model = Pororo(task="mt", lang="multi", model="transformer.large.multi.mtpg")

RTT 함수

def cycling_translation_en(sentence):
    english = model(sentence, 'ko', 'en') # 한->영
    korean = model(english, 'en', 'ko') # 영->한
    return korean

본격적인 데이터셋 만들기 (make_rtt_csv 함수)

아래는 저희의 raw data인 train.csv를 dataset으로 받아

데이터들을 raw data와 같은 형태(start_ind, type 등이 있는)로 csv 파일을 만들어주는 함수 make_rtt_csv입니다.

def make_rtt_csv(dataset):
    data_num = len(dataset)
    count = 0 # 총 몇 개의 데이터가 추가되는지 확인용
    lbl_candidate_list = ['per:place_of_residence', 'per:other_family', 'per:place_of_birth',
                     'org:founded_by', 'per:product', 'per:siblings', 'org:political/religious_affiliation',
                     'per:religion', 'per:schools_attended', 'org:dissolved', 'org:number_of_employees/members',
                     'per:place_of_death', ] # 개수가 193개 이하인 라벨들
    lbl_dict = dict.fromkeys(lbl_candidate_list, 0) # lbl_candidate_list를 키로 갖는 dictionary를 0으로 초기화

    ind_list = []
    new_sen_list = []
    new_subj_list = []
    new_obj_list = []
    label_list = []
    source_list = []
    for ind in range(data_num):
        _, sen, subj_dict, obj_dict, label, source = dataset.loc[ind]

        if label not in lbl_candidate_list: # 개수가 193개 초과인 라벨은 추가하지 않음
            continue
        rtt_sen = cycling_translation_en(sen)
        rtt_subj = cycling_translation_en(eval(subj_dict)['word'])
        rtt_obj = cycling_translation_en(eval(obj_dict)['word'])

        if rtt_subj not in rtt_sen or rtt_obj not in rtt_sen: # 번역된 subj, obj가 번역된 문장에 없으면 추가하지 않음
            continue
        if len(sen.split('.')) < len(rtt_sen.split('.')): # 문장이 여러 개로 만들어졌으면 추가하지 않는다 (잘못 번역되었을 확률 높음)
            continue

        subj_start = rtt_sen.find(rtt_subj)
        subj_end = subj_start + len(rtt_subj) - 1
        rtt_subj = str(dict({'word':rtt_subj,
                            'start_idx':subj_start,
                            'end_idx':subj_end,
                            'type':eval(subj_dict)['type'],
                            }))
        obj_start = rtt_sen.find(rtt_obj)
        obj_end = obj_start + len(rtt_obj) - 1
        rtt_obj = str(dict({'word':rtt_obj,
                            'start_idx':obj_start,
                            'end_idx':obj_end,
                            'type':eval(subj_dict)['type'],
                           }))

        ind_list.append(count)
        new_sen_list.append(rtt_sen)
        new_subj_list.append(rtt_subj)
        new_obj_list.append(rtt_obj)
        label_list.append(label)
        source_list.append(source)
        lbl_dict[label] += 1
        count += 1

        # 어떤 label을 갖는 데이터가 몇 개 추가 되었는지 출력하며 확인할 수 있습니다.
        print(f'[{ind} / {count} 개]\nORIGINAL: {sen}\nRTT: {rtt_sen}\nsubject: {rtt_subj}, object: {rtt_obj}, label: {label}')
        print(lbl_dict, '\n')
    new_data = {'sentence': new_sen_list,
               'subject_entity': new_subj_list,
               'object_entity': new_obj_list,
               'label': label_list,
               'source': source_list}
    new_data_df = pd.DataFrame(new_data)
    new_data_df.to_csv('/opt/ml/dataset/train/rtt_data.csv')

id column 이름 바꿔주기

id를 dictionary에 넣어서 csv 파일로 만들면 열이 하나 더 추가 되어서 그냥 안넣어주고 열 이름만 바꿨습니다.

# 근데 Unnamed 열은 왜 rename 함수로도 안바뀔까요ㅠ
rtt_data = pd.read_csv('/opt/ml/dataset/train/rtt_data.csv')
rtt_data.columns = ['id', 'sentence', 'subject_entity', 'object_entity', 'label', 'source']

결과

아래는 결과물 예시입니다! (이미지를 누르시면 크게 보실 수 있습니다.) image

위에서 설명드렸던 세 가지 조건으로 rtt를 수행했을 때, 총 306개의 데이터가 생성됩니다. 원래 데이터를 늘리려고 했던 클래스가 각각 40~193 개의 데이터를 갖고 있었어서 생성된 데이터 개수도 소량이네요...ㅠ 클래스별 추가된 데이터 개수는 아래와 같습니다.

{'per:place_of_residence': 58, 'per:other_family': 29, 'per:place_of_birth': 38, 'org:founded_by': 30, 'per:product': 14, 'per:siblings': 30, 'org:political/religious_affiliation': 20, 'per:religion': 16, 'per:schools_attended': 28, 'org:dissolved': 26, 'org:number_of_employees/members': 8, 'per:place_of_death': 9}

presto105 commented 3 years ago

rtt를 해주셨군요!! 목적에 따른 흐름이 인상깊습니다!!

j961224 commented 3 years ago

RTT 해주시다니..! 감사합니다~!!! 깔끔한 정리 감사합니다!