alycialee / beyond-scale-language-data-diversity

How are we guaranteeing that our model is right-shifting correctly? #22

Closed brando90 closed 1 hour ago

brando90 commented 1 hour ago

Given we have this manual training loop:

for step, batch in tqdm(enumerate(data_loader), desc='Iter step', total=len(data_loader), leave=leave_pbar_on_screen):
    optimizer.zero_grad()
    inputs = {'input_ids': batch['input_ids'].to(device),
              'attention_mask': batch['attention_mask'].to(device)}
    logits = self.model(**inputs, labels=inputs["input_ids"]).logits
    loss = self.loss_fn(logits, inputs["input_ids"], ignore_index=50256)
    if step == 0:
        print(f'\nInitial loss {loss.item()} ({step=} {epoch=})')
    error = get_error(logits, inputs['input_ids'], ignore_index=50256)

How are we guaranteeing the input is right-shifted? @alycialee:

input=[3, 0, 1] # eos a b
label=[0, 1, 3] # a b eos
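
One way to check the convention: Hugging Face causal LMs apply the shift internally when labels=input_ids is passed, scoring logits[:, :-1, :] against labels = input_ids[:, 1:]. That matches the input/label example above, except that no final eos target is appended; the last position is simply dropped. A minimal sketch (assuming gpt2, a single unpadded sentence, and HF's default ignore_index=-100 rather than the loop's ignore_index=50256) comparing a manual shifted loss against the model's built-in loss:

# Minimal sketch: check that "predict token t+1 from position t" (logits[:, :-1, :]
# vs labels = input_ids[:, 1:]) reproduces the loss HF computes internally when
# labels=input_ids is passed to a causal LM. Assumes gpt2 and one unpadded sentence.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()

input_ids = tokenizer("the quick brown fox", return_tensors='pt').input_ids  # (1, T)
with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids)

shift_logits = out.logits[:, :-1, :]  # drop the last position: nothing left to predict
shift_labels = input_ids[:, 1:]       # drop the first token: it is never a prediction target
manual_loss = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                              shift_labels.reshape(-1))

print(out.loss.item(), manual_loss.item())  # should agree up to floating-point error
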
brando90 commented 1 hour ago

https://github.com/alycialee/beyond-scale-language-data-diversity/blob/main/src/diversity/task2vec.py#L46

brando90 commented 1 hour ago
"""
Details of TFA: 

ref: https://chatgpt.com/c/66f9c323-5ca0-8001-bfeb-c83fa53a45a6
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import fire

def compute_tfa(model, tokenizer, input_texts):
    # Tokenize input texts with padding to handle the batch
    # Ensure a pad token exists so batches of unequal length can be padded
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(input_texts, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Get the model's output logits for the input_ids
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # Shape: (batch_size, sequence_length, vocab_size)

    # Ground-truth next tokens: input_ids shifted left by one position, so position t is scored against token t+1
    labels = input_ids[:, 1:].contiguous()  # e.g., ["a", "b"] --> ["b"]: the first token is never a prediction target

    # Drop the last position's logits: there is no next token to predict after the final input token
    logits = logits[:, :-1, :]

    # Calculate accuracy per token, excluding padded positions via the attention mask
    predicted_token_ids = torch.argmax(logits, dim=-1)  # Shape: (batch_size, sequence_length - 1)
    correct_predictions = (predicted_token_ids == labels)  # Boolean tensor of per-token hits
    label_mask = attention_mask[:, 1:].bool()  # aligned with the shifted labels; 0 at pad positions
    accuracy_per_token = correct_predictions[label_mask].float().mean().item()

    return accuracy_per_token

def main():
    # Define a small batch of input texts; each text must tokenize to at least
    # two tokens, otherwise the shift leaves no targets and the mean is NaN
    input_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "GPT-2 is a large transformer model for natural language processing."
    ]
    # input_texts = ['a', 'b']  # single-token inputs give no next-token targets

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    # Compute TFA
    start_time = time.time()
    tfa_score = compute_tfa(model, tokenizer, input_texts)
    end_time = time.time()

    # Output the results
    print(f"Teacher-Forced Accuracy (TFA): {tfa_score:.4f}")
    print(f"Time taken: {end_time - start_time:.2f} seconds")

if __name__ == "__main__":
    fire.Fire(main)
brando90 commented 1 hour ago

For teacher-forced accuracy (TFA).
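
A possible follow-up sanity check (hypothetical: it assumes the script above is saved as tfa.py so compute_tfa can be imported): run it on a batch whose texts have different token lengths, so padding is actually exercised and correct "predictions" at padded eos positions cannot inflate the score.

# Hypothetical usage sketch for compute_tfa from the script above (assumed to be
# saved as tfa.py; the module name is an assumption, not part of the repo).
from transformers import AutoTokenizer, AutoModelForCausalLM

from tfa import compute_tfa  # hypothetical module name

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()

texts = [
    "The quick brown fox jumps over the lazy dog.",  # longer sequence
    "Hello world.",                                  # shorter sequence, gets right-padded
]
tfa_score = compute_tfa(model, tokenizer, texts)
print(f"TFA on a padded batch: {tfa_score:.4f}")  # a single float in [0, 1]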