medal-challenger / jigsaw-rate-severity-of-toxic-comments

0 stars 0 forks source link

inference 에서 모델 load 방식 변경 #33

Closed ympaik87 closed 2 years ago

ympaik87 commented 2 years ago

모델 저장 방식이 #29 로 업댓된 관계로, inference에서 로드하는 방식에 대한 guideline이 필요합니다.

kkbwilldo commented 2 years ago

전체 코드는 kaggle/inference를 참조하시면 됩니다!

코드 사용법을 간단히 알려드리겠습니다.


import sys
sys.path.insert(0, '../input/jigsaw-repo/jigsaw-rate-severity-of-toxic-comments-feature-wo_ln/jigsaw_toxic_severity_rating/')

위 코드는 import 섹션에 있는 코드입니다. 해당 코드가 있어야 모델을 문제없이 불러올 수 있습니다.

먼저 저희 깃 레포지토리를 캐글 데이터셋으로 올립니다. 가져온 데이터 경로대로 위의 경로를 변경해주시면 됩니다. 코드 자체에 대한 설명은 노션에 적어두겠습니다.


# pt파일 경로
MODEL_WEIGHTS = glob('../input/deberta/deberta/*.pt')
MODEL_DIR = '../input/tokenizer/deberta-v3-base'

위 코드는 Tokenizer와 학습된 모델 파일을 불러오기 위해 사용하는 코드입니다. 해당 방식으로 관련 파일들을 불러옵니다.

def inference(model_paths, textloader, device):
    final_preds = []
    for i, path in enumerate(model_paths):

        model = torch.load(path)
        model.to(CONFIG['device'])

        print(f"Getting predictions for model {i+1}")
        preds = valid_fn(model, textloader, device)

        final_preds.append(preds)

        del model
        _ = gc.collect()
    final_preds = np.array(final_preds)
    final_preds = np.mean(final_preds, axis=0)
    return final_preds

모델을 직접 불러오는 코드입니다. 이전 코드와 크게 바뀐 부분은 없으며

모델을 불러올 때 다음과 같이 모델을 불러오시면 됩니다.

model = torch.load(path)

저희 깃 레포지토리 model.py 파일안에 학습시킨 모델 클래스(e.g., class JigsawModel())가 적혀있는 경우에만 위와 같이 코드를 사용할 수 있습니다.

만약 class JigsawModel을 사용하지 않고 다른 클래스(e.g., class ElectraModel())로 모델을 구성하였다면 해당 클래스가 model.py 안에 작성되어 있어야 합니다!


궁금하신 점이 있으시다면 이슈에 코멘트 남겨주세요!

Kingthegarden commented 2 years ago

실험이 더욱 편리해지겠네요! 감사합니다 ㅎㅎ

ympaik87 commented 2 years ago

상세한 설명 감사합니다!