boostcampaitech2 / mrc-level2-nlp-04

mrc-level2-nlp-04 created by GitHub Classroom
4 stars 5 forks source link

Dense retrieval 구현 동참 누가 빨리하나 시합 #22

Closed raki-1203 closed 2 years ago

raki-1203 commented 2 years ago

arguments.py

retrieval_encoder.py

custom encoder 클래스 만들어진 파일 저장하는 방법과 불러오는 방법이 헷갈려서 시간을 많이 뺏겼지만 이제 좀 알것같음 AutoModel.from_pretrained() 해서 불러오는 모델을 train 해주고 save_pretrained() 해주는 데 RetrievalEncoder.encoder.save_pretrained() 해주면 우리가 불러오는 부분만 저장되는 것 같음

여기서 갑자기 의문이 드는데 저기에 nn.linear() 이런 선형변환을 하나 추가해주면 쟤만 부르면 또 안되겠다는 생각이들어서.... 어찌해야 갑자기 생각이 드네요....

dense_retriaval.py

DenseRetrieval 이라는 클래스 선언된 파일

retrieval_train.py

scrach 로 retrieval train 하는 함수 구현해놓음 epoch 마다 validation score 계산하게 되어있음

python retrival_train.py
--project_name dense_retrieval_implement
--retrieval_run_name roberta-small
--use_trained_model False
--retrieval_model_name_or_path klue/roberta-small

이런식으로 인자 주어지면 모델이 저장되는 위치는 "./retrival_output/roberta-small" 여기에 p_encoder, q_encoder 나눠서 저장됨

retrieval_test.py

train_dataset 에 들어있는 train, validation 데이터를 모두 사용해서 retrieval 모델들이 얼마나 wiki 에서 정확한 문장을 잘 뽑아오는지 테스트하는 파일

k = 1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 에 대해서 test 해줌

python retrieval_test.py
--project_name retrieval_accuracy_comparison
--retrieval_run_name roberta-small
--use_trained_model True
--retrieval_model_name_or_path klue/roberta-small

train 에서의 retrieval_run_name 과 retrieval_model_name_or_path 를 맞춰주면 학습했던 모델을 사용할 수 있음

utils_retrieval.py