allenai / open-instruct


Tulu v2 Sankey Diagram #91

Closed · nuoma closed this 7 months ago

nuoma commented 7 months ago

Hey, I made a quick Sankey diagram for Tulu v2 and thought it might be interesting to share with you. The reason I'm doing this is that I noticed FLAN is reused by several different datasets, much like what many LLM efforts do with GSM8k, which could potentially cause data contamination. Unfortunately, I still cannot figure out some of the detailed relationships correctly.

[Sankey diagram of the Tulu v2 data mixture]

Made with https://sankeymatic.com/build/ using the following script:

FLAN v2 [50000] Tulu v2
FLAN v2 CoT [50000] Tulu v2
oasst1 [7708] Tulu v2
ShareGPT [114046] Tulu v2
GPT4-Alpaca [20000] Tulu v2
Code-Alpaca [20022] Tulu v2
LIMA [1030] Tulu v2
Evol Instruct [30000] Tulu v2
Open-Orca [30000] Tulu v2
Hardcoded [140] Tulu v2
Science [7544] Tulu v2

FLAN v2 CoT [75000] Open-Orca
FLAN v2 niv [75000] Open-Orca
FLAN v2 t0 [75000] Open-Orca
FLAN v2 flan [75000] Open-Orca

FLAN v2 CoT [75000] FLAN v2
FLAN v2 niv [75000] FLAN v2
FLAN v2 t0 [75000] FLAN v2
FLAN v2 flan [75000] FLAN v2

hamishivi commented 7 months ago

Hi, thanks for this useful figure! This seems right: we explicitly use the CoT subset of FLAN in addition to the rest of FLAN (as we wanted to 'upweight' the CoT data), and Open-Orca is derived from FLAN, but with completions from GPT models instead of the original labels (see https://huggingface.co/datasets/Open-Orca/OpenOrca). It's worth noting we also only use a subset of the FLAN data.
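
If you want to poke at that relationship yourself, something like the following should work; it just streams a few OpenOrca rows so you can see the FLAN-style prompts paired with GPT completions (a sketch only, assuming the Hugging Face datasets library and the field names listed on the dataset card):

# Stream a few OpenOrca examples to see the FLAN-style prompts paired with
# GPT-generated completions. Assumes the Hugging Face `datasets` library;
# field names are taken from the dataset card.
import itertools
from datasets import load_dataset

ds = load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)
for example in itertools.islice(ds, 3):
    print(example["system_prompt"])
    print(example["question"][:200])
    print(example["response"][:200])
    print("---")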

There is training data from GSM8k in the FLAN data, but afaik there should be no test data.
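
If you ever want to double-check that, a rough n-gram overlap scan against the GSM8k test split is easy to put together. This is only a sketch (it assumes the Hugging Face datasets library and simple whitespace tokenization), not something we ran:

# Rough contamination scan: flag training prompts that share a long n-gram
# with any GSM8k test question. A sketch only; assumes the Hugging Face
# `datasets` library and simple whitespace tokenization.
from datasets import load_dataset

def ngrams(text, n=13):
    toks = text.lower().split()
    return {" ".join(toks[i:i + n]) for i in range(len(toks) - n + 1)}

gsm8k_test = load_dataset("gsm8k", "main", split="test")
test_ngrams = set()
for ex in gsm8k_test:
    test_ngrams |= ngrams(ex["question"])

def flag_contaminated(train_prompts):
    """Return indices of training prompts that share a 13-gram with GSM8k test."""
    return [i for i, p in enumerate(train_prompts) if ngrams(p) & test_ngrams]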

nuoma commented 7 months ago

I made a quick t-SNE visualization; it suggests Tulu v2 semantically includes ShareGPT and roughly covers the semantic space of SlimOrca. Does this look right to you?

[t-SNE scatter plot of SlimOrca, Tulu v2, and ShareGPT embeddings]

code:

import os
import json
import numpy as np
from transformers import BertTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'  # Set your CUDA device

def load_text_data(input_file):
    """Extract the first human turn from each ShareGPT-format conversation."""
    texts = []
    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line)
            first_human_text = next((turn['value'] for turn in data['conversations'] if turn['from'] == 'human'), None)
            if first_human_text:
                texts.append(first_human_text)
    return texts

@torch.no_grad()
def compute_embeddings(texts, model, tokenizer, batch_size=100, max_length=512):
    """Encode texts in batches and return their [CLS] embeddings as a NumPy array."""
    model.eval()
    embeddings = []

    for i in tqdm(range(0, len(texts), batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc="Computing embeddings"):
        batch_texts = texts[i:i + batch_size]
        encoded_input = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to('cuda')
        output = model(**encoded_input)

        # Use the [CLS] token embedding as the sentence representation
        cls_embeddings = output.last_hidden_state[:, 0, :].cpu().numpy()
        embeddings.append(cls_embeddings)

        torch.cuda.empty_cache()

    return np.concatenate(embeddings, axis=0)

def tsne_plot(embeddings1, embeddings2, embeddings3):
    tsne = TSNE(n_components=2, random_state=42)
    combined_embeddings = np.vstack((embeddings1, embeddings2, embeddings3))
    reduced_embeddings = tsne.fit_transform(combined_embeddings)

    plt.figure(figsize=(12, 8))
    end1 = len(embeddings1)
    end2 = end1 + len(embeddings2)

    plt.scatter(reduced_embeddings[:end1, 0], reduced_embeddings[:end1, 1], color='blue', label='slim orca', alpha=0.5)
    plt.scatter(reduced_embeddings[end1:end2, 0], reduced_embeddings[end1:end2, 1], color='red', label='tulu v2', alpha=0.5)
    plt.scatter(reduced_embeddings[end2:, 0], reduced_embeddings[end2:, 1], color='green', label='sharegpt v3 53k', alpha=0.5)

    plt.title("t-SNE visualization of three text datasets")
    plt.xlabel("t-SNE dimension 1")
    plt.ylabel("t-SNE dimension 2")

    # Customize legend
    plt.legend(title="Datasets", loc="best", frameon=False)

    plt.savefig('tsne_three_datasets_comparison.png')

def main(dataset_path1, dataset_path2, dataset_path3):
    tokenizer = BertTokenizer.from_pretrained('../models/bert-base-uncased')
    model = AutoModel.from_pretrained('../models/bert-base-uncased').to('cuda')

    texts1 = load_text_data(dataset_path1)
    texts2 = load_text_data(dataset_path2)
    texts3 = load_text_data(dataset_path3)

    embeddings1 = compute_embeddings(texts1, model, tokenizer)
    embeddings2 = compute_embeddings(texts2, model, tokenizer)
    embeddings3 = compute_embeddings(texts3, model, tokenizer)

    tsne_plot(embeddings1, embeddings2, embeddings3)

if __name__ == "__main__":
    dataset_path1 = '../1220_slimorca_Sharegpt.jsonl'
    dataset_path2 = '../231201_tuluv2_sharegpt.jsonl'
    dataset_path3 = '../sharegpt_V3_format.jsonl'
    main(dataset_path1, dataset_path2, dataset_path3)

nuoma commented 7 months ago

Also, a follow-up. I understand the scope of this study is the dataset itself, but I'm still very curious about one thing. During SFT you basically train on the entire Tulu mix. Have you thought about training for one epoch on a subsample of simple (low-perplexity), single-turn dialogues, then moving on to the larger and more complex data for, say, two more epochs? It's basically the idea of curriculum learning. Do you think this strategy would result in a better chat model? I've seen several people mention it, but only as their little secret personal recipe, without any publication to support the idea.
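
For concreteness, the kind of staging I have in mind looks roughly like this (just a sketch: it scores each example with a small reference LM and sorts easy-to-hard; the model name, data format, and split fraction are placeholders, not anything from your recipe):

# Sketch of a perplexity-based curriculum: score each example with a small
# reference LM, then do a first SFT epoch on the lowest-perplexity slice
# before training on the full mix. "gpt2" and the 25% cutoff are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

@torch.no_grad()
def perplexity(text, max_length=512):
    enc = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
    out = model(**enc, labels=enc["input_ids"])
    return torch.exp(out.loss).item()

def build_curriculum(examples, easy_fraction=0.25):
    """Sort examples easy-to-hard by perplexity; return (easy_subset, full_sorted_set)."""
    # Each example is assumed to be a dict with a "text" field (prompt + response).
    scored = sorted(examples, key=lambda ex: perplexity(ex["text"]))
    easy = scored[: int(len(scored) * easy_fraction)]
    return easy, scored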

hamishivi commented 7 months ago

Hi, sorry for the delayed response - was travelling!

The t-SNE looks reasonable. I think it's kind of hard to reason about the many dimensions the data can have with these visualisations, but showing that Tulu captures the diversity of ShareGPT and Orca (which are part of the Tulu mix) makes sense.
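
One cheap, more quantitative sanity check (just a sketch, assuming scikit-learn and the BERT embedding arrays your script already computes) is to look at nearest-neighbour distances in the original embedding space rather than the 2-D projection:

# Coverage check in the original embedding space: for each point in a source
# set, find the distance to its nearest neighbour among the Tulu embeddings.
# Small distances for most points suggest the mix "covers" that source.
import numpy as np
from sklearn.neighbors import NearestNeighbors

def coverage_distances(source_embeddings, tulu_embeddings):
    """Return the nearest-Tulu-neighbour distance for every source example."""
    nn = NearestNeighbors(n_neighbors=1).fit(tulu_embeddings)
    distances, _ = nn.kneighbors(source_embeddings)
    return distances.ravel()

# e.g. np.percentile(coverage_distances(orca_embeddings, tulu_embeddings), [50, 90, 99])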

With curriculum learning, we haven't tried anything there, but I think it's an interesting avenue for future work. I think papers like Skill-It show promise, but AFAIK no one has shown this to work for the generalized chat setting in a paper with modern, performant models and data.