TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

[Bug Report] `tokenize_and_concatenate` doesn't work with small datasets. #707

Open yash-srivastava19 opened 3 weeks ago

yash-srivastava19 commented 3 weeks ago

Describe the bug

The docstring itself mentions that the `tokenize_and_concatenate` function doesn't work properly with small datasets. I wanted to figure out whether there is a workaround that can be used.

Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why.

Code example

The dataset I'm using is small, and sometimes contains only a single word. Here is minimal code that reproduces the error.

from datasets.load import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

#----------- Utility Functions -------------

def create_dataset(link):
    ds = load_dataset(
        path=link,
        split="train",
        streaming=False,
    )
    return ds

# `model` and `sae` are defined elsewhere in my setup.
def get_tokens(
    dataset,
    tokenizer=model.tokenizer,
    streaming=True,
    max_length=sae.cfg.context_size,
    column_name="text",
    add_bos_token=sae.cfg.prepend_bos,
):
    return tokenize_and_concatenate(
        dataset=dataset,
        tokenizer=tokenizer,
        streaming=streaming,
        max_length=max_length,
        column_name=column_name,
        add_bos_token=add_bos_token,
    )

#----------- Example Usage -------------
DATASET_1 = "link-to-big-dataset"
DATASET_2 = "link-to-small-dataset"

dataset_1 = create_dataset(DATASET_1)
dataset_1_tokens  = get_tokens(dataset_1) # This gets executed. No issues.

dataset_2 = create_dataset(DATASET_2)
dataset_2_tokens  = get_tokens(dataset_2) # This line breaks.
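
Since the actual dataset links are elided above, here is a self-contained sketch that should hit the same error; the GPT-2 tokenizer and the one-word in-memory dataset are just stand-ins for illustration.

from datasets import Dataset
from transformers import AutoTokenizer

from transformer_lens.utils import tokenize_and_concatenate

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

# A single short document: far fewer tokens than one context window.
tiny_dataset = Dataset.from_dict({"text": ["hello"]})

# With max_length=128, the concatenated text is shorter than one sequence,
# so this call should fail with the ValueError shown below.
tiny_tokens = tokenize_and_concatenate(tiny_dataset, tokenizer, max_length=128, num_proc=1)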

Here's what the error stack trace looks like:

...
File /opt/conda/lib/python3.10/site-packages/transformer_lens/utils.py:358, in tokenize_and_concatenate(dataset, tokenizer, streaming, max_length, column_name, add_bos_token, num_proc)
    350     return {"tokens": tokens}
    352 tokenized_dataset = dataset.map(
    353     tokenize_function,
    354     batched=True,
    355     num_proc=(num_proc if not streaming else None),
    356     remove_columns=[column_name],
    357 )
--> 358 tokenized_dataset.set_format(type="torch", columns=["tokens"])
    359 return tokenized_dataset

File /opt/conda/lib/python3.10/site-packages/datasets/fingerprint.py:482, in fingerprint_transform.<locals>._fingerprint.<locals>.wrapper(*args, **kwargs)
    478             validate_fingerprint(kwargs[fingerprint_name])
    480 # Call actual function
--> 482 out = func(dataset, *args, **kwargs)
    484 # Update fingerprint of in-place transforms + update in-place history of transforms
    486 if inplace:  # update after calling func so that the fingerprint doesn't change if the function fails

File /opt/conda/lib/python3.10/site-packages/datasets/arrow_dataset.py:2596, in Dataset.set_format(self, type, columns, output_all_columns, **format_kwargs)
   2594     missing_columns = set(columns) - set(self._data.column_names)
   2595     if missing_columns:
-> 2596         raise ValueError(
   2597             f"Columns {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
   2598         )
   2599 if columns is not None:
   2600     columns = columns.copy()  # Ensures modifications made to the list after this call don't cause bugs

ValueError: Columns ['tokens'] not in the dataset. Current columns in the dataset: ['text']

This works perfectly well for DATASET_1, but breaks for DATASET_2.
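
For reference, my reading of transformer_lens/utils.py is that the "tokens" column goes missing because the number of full sequences rounds down to zero for a tiny dataset (the follow-up comment below works through the same thing). A rough sketch with made-up numbers:

# Illustrative numbers only: max_length=128 and add_bos_token=True assumed.
seq_len = 128 - 1            # one slot is reserved for the BOS token
num_tokens = 3               # a single short word tokenizes to only a few tokens
num_batches = num_tokens // seq_len
print(num_batches)           # 0 -> tokens[: seq_len * num_batches] is empty, so the
                             # mapped batches write no "tokens" rows and set_format fails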


yash-srivastava19 commented 2 weeks ago

After tinkering with the tokenize_and_concatenate function a little, I was able to work around the issue (for my case) by removing the chunking part of the code. For small datasets the number of batches is 0, and that's what creates the problem. Here's the refactored code. If possible, can you tell me whether this approach is okay?

...

# Imports needed when running this refactor standalone (they are already at the top of transformer_lens/utils.py):
import einops
import numpy as np

from transformer_lens.utils import keep_single_column

def tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
):
    """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.

    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence. Defaults to True.

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"

    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why
    """
    dataset = keep_single_column(dataset, column_name)
    if tokenizer.pad_token is None:
        # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Define the length to chop things up into - leaving space for a bos_token if required
    if add_bos_token:
        seq_len = max_length - 1
    else:
        seq_len = max_length

    def tokenize_function(examples):
        text = examples[column_name]
        # Concatenate it all into an enormous string, separated by eos_tokens
        full_text = tokenizer.eos_token.join(text)

        # Instead of chunking the text into 20 pieces, tokenize the full concatenated string in one call.
        tokens = tokenizer(full_text, return_tensors="np", padding=True)["input_ids"].flatten()
        # Drop padding tokens
        tokens = tokens[tokens != tokenizer.pad_token_id]

        num_tokens = len(tokens)
        num_batches = num_tokens // (seq_len)

        # Drop the trailing tokens that don't fill a full sequence, but keep everything
        # when the whole dataset is smaller than a single sequence (num_batches == 0).
        tokens = tokens[: seq_len * num_batches] if num_batches else tokens

        if num_batches:  # enough tokens for at least one full sequence: proceed the standard way
            tokens = einops.rearrange(
                tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
            )
            if add_bos_token:
                prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
                tokens = np.concatenate([prefix, tokens], axis=1)
        else:
            # Too few tokens for even one sequence: return them as a flat array instead of dropping them.
            tokens = np.array(tokens)
        return {"tokens": tokens}

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=(num_proc if not streaming else None),
        remove_columns=[column_name],
    )
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
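
A quick way to sanity-check the refactor (this assumes the function above is defined in the current session, shadowing the library version; the GPT-2 tokenizer and the one-word dataset are just placeholders):

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# The kind of tiny dataset that previously raised the ValueError.
tiny_dataset = Dataset.from_dict({"text": ["hello"]})

# Calls the refactored tokenize_and_concatenate defined above (not the library one).
tiny_tokens = tokenize_and_concatenate(tiny_dataset, tokenizer, max_length=128, num_proc=1)
print(tiny_tokens["tokens"])  # the few tokens are kept (as a flat sequence) instead of being dropped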