Closed kuailehaha closed 2 months ago
Hi, I used the following script to preprocess each subject based on the MedMCQA training split.
And I have just uploaded the pre-processed evaluation splits on HuggingFace, so you can directly download them now: med_knowledge_prob and law_knowledge_prob.
For the evaluation score, I computed the final score as the average accuracy over the 21 subjects, as you did. Regarding the final score, I think some absolute differences are acceptable as long as the relative trend between the general LLM and the LLM after DAPT remains consistent. I hope this helps!
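To be concrete, "average accuracy over 21 subjects" here means a macro-average: score each subject split separately, then take the unweighted mean of the per-subject accuracies. This generally differs from pooling all questions into one overall accuracy. A minimal sketch (the subject counts below are made up purely for illustration):

```python
# Macro-average accuracy: score each subject separately, then average
# the per-subject accuracies with equal weight per subject.
# The numbers below are hypothetical, for illustration only.
per_subject = {
    'Anatomy':   {'correct': 120, 'total': 300},
    'Pathology': {'correct': 150, 'total': 500},
    'Unknown':   {'correct': 10,  'total': 40},
}

subject_accs = [s['correct'] / s['total'] for s in per_subject.values()]
macro_avg = sum(subject_accs) / len(subject_accs)  # unweighted mean over subjects

# Contrast with overall (micro) accuracy, which pools all questions and
# therefore weights large subjects more heavily:
micro_avg = (sum(s['correct'] for s in per_subject.values())
             / sum(s['total'] for s in per_subject.values()))

print(f'macro-average accuracy: {macro_avg:.3f}')
print(f'overall (micro) accuracy: {micro_avg:.3f}')
```

Because subject sizes vary a lot in MedMCQA, the two metrics can diverge by a few points, which may explain part of a score discrepancy.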
"""
usage:
python preprocess_knowledge_probing.py \
--train_set './train.json' \
--output_dir './MedMCQA_Know_Prob' \
--random_seed 0
"""
import argparse
import os
import random

import jsonlines

parser = argparse.ArgumentParser()
parser.add_argument('--train_set', type=str, help='path to the `train.json` of MedMCQA', default='./train.json')
parser.add_argument('--output_dir', type=str, help='directory for saving the splits', default='./MedMCQA_Know_Prob')
parser.add_argument('--random_seed', type=int, default=0, help='random seed for shuffling the processed data')
args = parser.parse_args()

def read_jsonl(path):
    """Read a JSON Lines file into a list of dicts."""
    data = []
    with jsonlines.open(path) as reader:
        for entry in reader:
            data.append(entry)
    return data
data = read_jsonl(args.train_set)
def keep_entry(entry):  # renamed from `filter` to avoid shadowing the builtin
    # 1. Fix an issue with the dataset, see https://github.com/medmcqa/medmcqa/issues/4
    if entry["choice_type"] == "multi":
        return False
    # 2. Drop entries in the instruction-following format, such as question-answer pairs starting with `which`
    question_marks = ['which', 'what', 'when', 'where', 'how', 'who', 'why', '__', ':', '?', '-', '–']
    for mark in question_marks:
        if mark in entry['question'].lower():
            return False
    return True

filtered_data = [x for x in data if keep_entry(x)]
print(f'filtered_data length={len(filtered_data)}')
SUBJECTS = ['Anaesthesia', 'Anatomy', 'Biochemistry', 'Dental', 'ENT',
'Forensic Medicine', 'Gynaecology & Obstetrics', 'Medicine', 'Microbiology', 'Ophthalmology',
'Orthopedics', 'Pathology', 'Pediatrics', 'Pharmacology', 'Physiology',
'Psychiatry', 'Radiology', 'Skin', 'Social & Preventive Medicine', 'Surgery',
'Unknown']
def save_jsonl(docs, out_path):
    # mode='w' overwrites any existing file, so no need to remove it first
    with jsonlines.open(out_path, mode='w') as writer:
        for doc in docs:
            writer.write(doc)
    print('saved jsonl to:', out_path)
os.makedirs(args.output_dir, exist_ok=True)  # create the output directory if it does not exist
all_counts = 0
for sub in SUBJECTS:
    sub_data = [entry for entry in filtered_data
                if entry["subject_name"] == sub or entry['topic_name'] == sub]
    print(f'{sub} data length = {len(sub_data)}')
    all_counts += len(sub_data)
    random.Random(args.random_seed).shuffle(sub_data)  # shuffle the data order; this is optional
    save_jsonl(sub_data, os.path.join(args.output_dir, f'{sub}.jsonl'))
print(f"all_counts: {all_counts}")
Thank you for your response. Best wishes!
Hi. I'm a student following your work "Adapting LLM via Reading Comprehension". I am trying to replicate the knowledge probing test. I processed the MedMCQA dataset with the regex
pattern = r"^(What|When|Who)[^\n]*|[:?-]$|__"
according to the paper. The train.json yields 59,348 data entries and the dev.json yields 1,103. But I can't replicate the scores of 36.5 for the general LLM Llama-7B and 36.9 for my DAPT LLM: they scored approximately 32 points by overall accuracy and 35 by average accuracy over the 21 subjects. Could you please provide more details on the knowledge probing test? Thanks!
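For reference, this is how I applied the regex filter above; the sample questions here are hypothetical, just to show which patterns get removed:

```python
import re

# The pattern quoted above: drop questions starting with What/When/Who,
# ending with ':', '?', or '-', or containing a '__' blank.
pattern = r"^(What|When|Who)[^\n]*|[:?-]$|__"

# Hypothetical sample questions, for illustration only.
questions = [
    "What is the most common cause of iron deficiency anemia",  # starts with 'What' -> filtered
    "The ____ nerve supplies the diaphragm",                    # contains '__'      -> filtered
    "Treatment of choice for scabies is:",                      # ends with ':'      -> filtered
    "Lines of Blaschko follow the distribution of",             # no match           -> kept
]

kept = [q for q in questions if not re.search(pattern, q)]
print(kept)
```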