boostcamp-5th-NLP05 / level1_semantictextsimilarity-nlp-05

0 stars 0 forks source link

영어 토큰 전처리 코드 모은 파일 (2023/04/18) #8

Open yunjinchoidev opened 1 year ago

yunjinchoidev commented 1 year ago

(04/18) 현재까지 돌려 본 결과로는 영어 토큰을 전처리하지 않는 쪽이 성능이 더 좋습니다.

from selenium import webdriver
from selenium.webdriver.common.by import By
import time
from selenium.webdriver.common.keys import Keys
import re
import os
import argparse
import json
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from konlpy.tag import Mecab
from googletrans import Translator

# Module-level Mecab morphological analyzer (konlpy), shared by every
# `mecab.morphs(...)` call in this file to split sentences into morphemes.
mecab = Mecab()

def has_english(self, s, tokenize=None):
    """Return True if *s* contains an English (ASCII-letter) token.

    The leading ``self`` parameter is unused; it is kept only so that
    existing call sites passing a placeholder first argument keep working.

    Args:
        self: Ignored placeholder (this is a module-level function).
        s: Sentence to inspect.
        tokenize: Optional callable mapping a string to a list of tokens.
            Defaults to ``mecab.morphs`` (the module-level Mecab analyzer).

    Returns:
        True when at least one token consists solely of ASCII letters and is
        not ``"PERSON"`` (presumably the remains of a ``<PERSON>`` mask in the
        data -- confirm), otherwise False.
    """
    if tokenize is None:
        tokenize = mecab.morphs
    # bytes.isalpha() is True only for non-empty, pure-ASCII letter strings,
    # so Hangul tokens (which str.isalpha() would accept) are excluded.
    return any(
        tok.encode().isalpha() and tok != "PERSON" for tok in tokenize(s)
    )

def preprocess_eng_tok(s, to_korean_func, tokenize=None):
    """Replace every English token in *s* with its Korean translation.

    Tokens are processed left to right. Each one is located in the partially
    rewritten string starting just after the previous replacement.

    Args:
        s: Sentence to rewrite.
        to_korean_func: Callable translating one English token to Korean.
            A ``None`` or empty result leaves that token unchanged.
        tokenize: Optional callable mapping a string to a list of tokens.
            Defaults to ``mecab.morphs`` (the module-level Mecab analyzer).

    Returns:
        *s* with each ASCII-alphabetic token other than ``"PERSON"`` replaced
        by ``to_korean_func(token)``.
    """
    if tokenize is None:
        tokenize = mecab.morphs
    start = 0  # search position in *s*; text before it is already rewritten
    for tok in tokenize(s):  # tokenized once, on the original sentence
        # bytes.isalpha(): pure-ASCII letters only, so Hangul is skipped.
        if not tok.encode().isalpha() or tok == "PERSON":
            continue
        pos = s.find(tok, start)
        if pos == -1:
            # Token not present verbatim (e.g. normalized by the analyzer);
            # skip it instead of splicing at index -1.
            continue
        res = to_korean_func(tok)
        if not res:
            res = tok  # translation failed: keep the original token
        s = s[:pos] + res + s[pos + len(tok):]
        # Advance past the inserted text, whose length generally differs
        # from the token it replaced.
        start = pos + len(res)
    return s

class PapagoCrawler:
    """Translate English text to Korean by scripting the Papago web page.

    A headless Chrome session is opened on construction and reused by every
    call to :meth:`to_korean`. Call :meth:`close` (or use the instance in a
    ``with`` statement) to shut the browser down when finished.
    """

    def __init__(self):
        chrome_options = webdriver.ChromeOptions()
        chrome_options.add_argument('--headless')
        chrome_options.add_argument('--no-sandbox')
        chrome_options.add_argument('--disable-dev-shm-usage')
        # `options=` is accepted by Selenium 3.8+ and 4.x; the `chrome_options=`
        # keyword and the positional driver path were removed in Selenium 4.10.
        # With no explicit path, a `chromedriver` executable is resolved by
        # Selenium itself (PATH lookup).
        self.driver = webdriver.Chrome(options=chrome_options)

    def to_korean(self, input_sentence, max_count=5, poll_interval=0.5,
                  max_polls=20):
        """Translate *input_sentence* from English to Korean via Papago.

        Args:
            input_sentence: English text to type into the source box.
            max_count: Attempts at locating the source text box.
            poll_interval: Seconds to wait between lookup attempts.
            max_polls: Attempts at reading the translated text.

        Returns:
            The text shown in Papago's target pane.

        Raises:
            RuntimeError: The source text box could not be found.
            TimeoutError: No usable translation appeared in time.
        """
        from selenium.common.exceptions import WebDriverException

        self.driver.get("https://papago.naver.com/?sk=en&tk=ko")

        input_box = None
        for _ in range(max_count):
            try:
                input_box = self.driver.find_element(By.XPATH, '//*[@id="txtSource"]')
            except WebDriverException:
                time.sleep(poll_interval)
            else:
                break
        if input_box is None:
            raise RuntimeError("Papago source text box (#txtSource) not found")

        input_box.send_keys(input_sentence)
        time.sleep(1)  # give the page time to start translating
        for _ in range(max_polls):
            try:
                translated_text = self.driver.find_element(
                    By.XPATH, '//*[@id="txtTarget"]/span').text
            except WebDriverException:
                translated_text = ''
            # '' / '...' presumably mean the result has not rendered yet.
            if translated_text and translated_text != '...':
                return translated_text
            time.sleep(poll_interval)
        raise TimeoutError("Papago returned no translation in time")

    def close(self):
        """Quit the headless browser and its driver process."""
        self.driver.quit()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

class GoogletransTranslator:
    """English-to-Korean translation through the ``googletrans`` client."""

    def __init__(self):
        self.trans = Translator()

    def to_korean(self, s, max_try=5):
        """Translate *s* from English to Korean.

        googletrans is unstable, so the request is retried up to *max_try*
        times.

        Returns:
            The translated text, or ``None`` if every attempt raised.
        """
        for _ in range(max_try):
            try:
                return self.trans.translate(s, src="en", dest="ko").text
            except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
                continue
        return None

    def preprocess_eng_tok(self, s, tokenize=None):
        """Replace every English token in *s* with ``self.to_korean(token)``.

        Args:
            s: Sentence to rewrite.
            tokenize: Optional callable mapping a string to a list of tokens.
                Defaults to ``mecab.morphs`` (the module-level Mecab analyzer).

        Returns:
            *s* with each ASCII-alphabetic token other than ``"PERSON"``
            translated. A token whose translation fails (``None``/empty) is
            left unchanged.
        """
        if tokenize is None:
            tokenize = mecab.morphs
        start = 0  # search position; text before it is already rewritten
        for tok in tokenize(s):
            # bytes.isalpha(): pure-ASCII letters only, so Hangul is skipped.
            if not tok.encode().isalpha() or tok == "PERSON":
                continue
            pos = s.find(tok, start)
            if pos == -1:
                continue  # not present verbatim; never splice at index -1
            res = self.to_korean(tok) or tok
            s = s[:pos] + res + s[pos + len(tok):]
            # Advance past the inserted text, not the replaced token.
            start = pos + len(res)
        return s

class MbartTranslator:
    """English-to-Korean translation with the mBART-50 one-to-many model."""

    def __init__(self):
        self.translator_tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-one-to-many-mmt",
            src_lang="en_XX",
            max_length=120,
            additional_special_tokens=["<PERSON>"],
        )
        self.translator = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-one-to-many-mmt"
        )
        # Match the embedding matrix to the tokenizer, which may have grown
        # by the added "<PERSON>" token.
        self.translator.resize_token_embeddings(len(self.translator_tokenizer))

    def has_english(self, s, tokenize=None):
        """Return True if *s* contains an ASCII-alphabetic token other than "PERSON".

        Args:
            s: Sentence to inspect.
            tokenize: Optional callable mapping a string to a list of tokens.
                Defaults to ``mecab.morphs`` (the module-level Mecab analyzer).
        """
        if tokenize is None:
            tokenize = mecab.morphs
        # bytes.isalpha(): pure-ASCII letters only, so Hangul is excluded.
        return any(
            tok.encode().isalpha() and tok != "PERSON" for tok in tokenize(s)
        )

    def to_korean(self, sentence):
        """Translate an English *sentence* to Korean with mBART-50.

        Returns:
            The decoded first generated sequence, without the language-code,
            BOS/EOS and padding tokens and without surrounding whitespace.
        """
        tok = self.translator_tokenizer
        ko_id = tok.convert_tokens_to_ids("ko_KR")
        model_inputs = tok(sentence, return_tensors="pt")
        generated_tokens = self.translator.generate(
            **model_inputs,
            forced_bos_token_id=ko_id,
        )
        # Remove framing tokens by id instead of slicing fixed character
        # offsets off the decoded string. A fixed slice cuts real text when
        # generation stops without a trailing "</s>". "<PERSON>" is kept.
        framing = {ko_id, tok.eos_token_id, tok.pad_token_id, tok.bos_token_id}
        ids = [t for t in generated_tokens[0].tolist() if t not in framing]
        return tok.decode(ids).strip()

    def preprocess_eng_tok(self, s, tokenize=None):
        """Replace every English token in *s* with ``self.to_korean(token)``.

        Args:
            s: Sentence to rewrite.
            tokenize: Optional callable mapping a string to a list of tokens.
                Defaults to ``mecab.morphs`` (the module-level Mecab analyzer).

        Returns:
            *s* with each ASCII-alphabetic token other than ``"PERSON"``
            translated. A token whose translation is ``None``/empty is left
            unchanged.
        """
        if tokenize is None:
            tokenize = mecab.morphs
        start = 0  # search position; text before it is already rewritten
        for tok in tokenize(s):
            if not tok.encode().isalpha() or tok == "PERSON":
                continue
            pos = s.find(tok, start)
            if pos == -1:
                continue  # not present verbatim; never splice at index -1
            res = self.to_korean(tok) or tok
            s = s[:pos] + res + s[pos + len(tok):]
            # Advance past the inserted text, not the replaced token.
            start = pos + len(res)
        return s

if __name__ == "__main__":
    # Smoke test: translate one sentence through Papago, then always quit the
    # headless browser so no Chrome/driver process is left running.
    pp = PapagoCrawler()
    try:
        print("PapagoCrawler: It's started --> ", pp.to_korean("It's started!"))
    finally:
        pp.driver.quit()