bcaitech1 / p4-fr-9-googoo

p4-fr-p4-fr-9-googoo created by GitHub Classroom
0 stars 0 forks source link

train.py의 id_to_string method 오류 #7

Open jhj9109 opened 3 years ago

jhj9109 commented 3 years ago

토론 글을 통해서 제시된 오류. Score를 계산하는 변환 시에 <EOS>를 만나면 종료되는 코드 추가하여 해결

원본 코드

def id_to_string(tokens, data_loader,do_eval=0):
    result = []
    if do_eval:
        special_ids = [data_loader.dataset.token_to_id["<PAD>"], data_loader.dataset.token_to_id["<SOS>"],
                       data_loader.dataset.token_to_id["<EOS>"]]

    for example in tokens:
        string = ""
        if do_eval:
            for token in example:
                token = token.item()
                if token not in special_ids:
                    if token != -1:
                        string += data_loader.dataset.id_to_token[token] + " "
        else:
            for token in example:
                token = token.item()
                if token != -1:
                    string += data_loader.dataset.id_to_token[token] + " "

문제점

<EOS> 이후에도 스페셜 토큰이 아닌 어떤 값(id)이 결과로 남아있을때, 불필요하게 이를 고려하여 score에 영향을 끼치게 된다.

e.g. teacher forcing이 적용될 때. <EOS>를 예측했지만, 올바른 출력이 아니라면 => 올바른 출력이 다시 입력으로 들어가 예측에 활용되어 <EOS> 이후에도 맞지 않은 토큰 값들이 출력된다. e.g. gt: <SOS> { 1 } + { 2 } = { 3 } <EOS> <PAD> <PAD> ... pred: <SOS> { 1 } <EOS> { 2 } = { 3 } <EOS> <PAD> ...

수정 코드

def id_to_string(tokens, data_loader,do_eval=0): # 0 Preds 1 -1 -1....
    result = []
    if do_eval:
        eos_id =  data_loader.dataset.token_to_id["<EOS>"]
        pad_id = data_loader.dataset.token_to_id["<PAD>"]
        sos_id = data_loader.dataset.token_to_id["<SOS>"]
        pad_id2 = -1
        ignore_ids = {
            pad_id : 1,
            sos_id : 1,
            pad_id2 : 1,
        }
    for example in tokens:
        string = ""
        if do_eval:  # 계산 용도 => score 와 관련이 있다.
            for token in example:
                token = token.item()
                if token == eos_id: # <EOS>만나면 종료한다.
                    break
                if token not in ignore_ids: # eos 외 무시할 id들을 체크한다.
                    string += data_loader.dataset.id_to_token[token] + " "
        else: # display 용도.
            for token in example:
                token = token.item()
                if token != -1: # 길이 채우기 위한 -1만 무시한다.
                    string += data_loader.dataset.id_to_token[token] + " "

        result.append(string)
    return result