sail-sg / regmix

🧬 RegMix: Data Mixture as Regression for Language Model Pre-training
MIT License

Can you share the script for generating the prior distribution of the token? #8

Closed: a154377713 closed this issue 1 day ago

SivilTaram commented 1 day ago

Hello @a154377713, you can use the following demo code to print a rough prior distribution over the domains. Two warnings: you have to preprocess the dataset using a relatively small chunk size (as here), so that the number of chunks per domain is a reasonable proxy for its number of tokens, and this code assumes that the domain prefix is separated from the rest of the file name by -:

import os
import random

def get_prefix_tokens(folder_path):
    # list all entries in the folder
    all_files = os.listdir(folder_path)
    # shuffle (this does not affect the counts below)
    random.shuffle(all_files)
    # keep only the training chunks, i.e. file names containing "train_"
    all_files = [file_path for file_path in all_files if "train_" in file_path]
    prefix_dict = {}
    for file_path in all_files:
        # get the prefix
        prefix = file_path.split("-")[0].strip()
        if prefix in prefix_dict:
            prefix_dict[prefix] += 1
        else:
            prefix_dict[prefix] = 1
    # print each prefix's share of the chunk files (the shares sum to 1)
    total = sum(prefix_dict.values())
    for prefix in prefix_dict:
        print(prefix + ": ", prefix_dict[prefix]/total)

Assuming your preprocessed dataset is in the folder lit_dataset_regmix, you can call the function as follows:

get_prefix_tokens("lit_dataset_regmix")
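
If you want to reuse the weights (for example, to paste them into a config) rather than only print them, the same counting can be sketched as a function that returns a dictionary. This is a minimal sketch and not code from the regmix repository: the name domain_weights and its marker and sep parameters are hypothetical, and it keeps the same assumptions as above (chunk files contain "train_" and the domain prefix ends at the first -). Unlike the demo above, it skips sub-directories.

```python
from collections import Counter
from pathlib import Path


def domain_weights(folder_path, marker="train_", sep="-"):
    """Return {domain_prefix: share of chunk files} for one folder.

    Hypothetical helper (not part of regmix). The share of chunk files
    only approximates the share of tokens when every chunk has roughly
    the same size, as noted in the warning above.
    """
    # Count one chunk per regular file whose name contains `marker`;
    # the domain prefix is everything before the first `sep`.
    counts = Counter(
        path.name.split(sep)[0].strip()
        for path in Path(folder_path).iterdir()
        if path.is_file() and marker in path.name
    )
    total = sum(counts.values())
    if total == 0:
        raise ValueError(f"no files containing {marker!r} in {folder_path}")
    # Normalize the counts so the weights sum to 1.
    return {prefix: n / total for prefix, n in sorted(counts.items())}
```

Calling domain_weights("lit_dataset_regmix") should then give the same numbers that get_prefix_tokens prints, as a dict keyed by prefix.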

a154377713 commented 1 day ago

Thanks for the reply!