Closed ympaik87 closed 2 years ago
전체 코드는 kaggle/inference를 참조하시면 됩니다!
코드 사용법을 간단히 알려드리겠습니다.
import sys
sys.path.insert(0, '../input/jigsaw-repo/jigsaw-rate-severity-of-toxic-comments-feature-wo_ln/jigsaw_toxic_severity_rating/')
위 코드는 import 섹션에 있는 코드입니다. 해당 코드가 있어야 모델을 문제없이 불러올 수 있습니다.
먼저 저희 깃 레포지토리를 캐글 데이터셋으로 올립니다. 가져온 데이터 경로대로 위의 경로를 변경해주시면 됩니다. 코드 자체에 대한 설명은 노션에 적어두겠습니다.
# pt파일 경로
MODEL_WEIGHTS = glob('../input/deberta/deberta/*.pt')
MODEL_DIR = '../input/tokenizer/deberta-v3-base'
위 코드는 Tokenizer와 학습된 모델 파일을 불러오기 위해 사용하는 코드입니다. 해당 방식으로 관련 파일들을 불러옵니다.
def inference(model_paths, textloader, device):
final_preds = []
for i, path in enumerate(model_paths):
model = torch.load(path)
model.to(CONFIG['device'])
print(f"Getting predictions for model {i+1}")
preds = valid_fn(model, textloader, device)
final_preds.append(preds)
del model
_ = gc.collect()
final_preds = np.array(final_preds)
final_preds = np.mean(final_preds, axis=0)
return final_preds
모델을 직접 불러오는 코드입니다. 이전 코드와 크게 바뀐 부분은 없으며
모델을 불러올 때 다음과 같이 모델을 불러오시면 됩니다.
model = torch.load(path)
저희 깃 레포지토리 model.py 파일안에 학습시킨 모델 클래스(e.g., class JigsawModel())가 적혀있는 경우에만 위와 같이 코드를 사용할 수 있습니다.
만약 class JigsawModel을 사용하지 않고 다른 클래스(e.g., class ElectraModel())로 모델을 구성하였다면 해당 클래스가 model.py 안에 작성되어 있어야 합니다!
궁금하신 점이 있으시다면 이슈에 코멘트 남겨주세요!
실험이 더욱 편리해지겠네요! 감사합니다 ㅎㅎ
상세한 설명 감사합니다!
모델 저장 방식이 #29 로 업댓된 관계로, inference에서 로드하는 방식에 대한 guideline이 필요합니다.