
CALM - Pytorch

Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google DeepMind

Supports any number of augmentation LLMs

Install

$ pip install CALM-pytorch

Appreciation

Usage

ex. with x-transformers

import torch
from x_transformers import TransformerWrapper, Decoder

augment_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

calm = CALM(
    anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)

# mock input

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()
prompt = torch.randint(0, 20000, (1, 256))

# forward for finetuning loss

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

# after much training, prompt the composed model

generated = calm.generate(
    prompt = seq[:, :1],
    seq_len = 1024
)
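
If you would rather write your own training loop than use the FineTuner class described below, a minimal sketch follows. It assumes (as the checkpointing note further down suggests) that the wrapper leaves only the newly added cross attention parameters trainable; the optimizer and hyperparameters here are illustrative, not part of the library.

from torch.optim import Adam

# hypothetical manual finetuning loop - optimize only the parameters left trainable by the wrapper

optimizer = Adam([p for p in calm.parameters() if p.requires_grad], lr = 3e-4)

for _ in range(100):
    loss = calm(seq, mask = mask, prompt = prompt)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()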

To use the handy trainer class built on 🤗 Accelerate, just import FineTuner and use it as follows

from CALM_pytorch import FineTuner

trainer = FineTuner(
    calm = calm,
    dataset = dataset,   # returns a dictionary of input kwargs to calm - dict(seq = Tensor, mask = Tensor, prompt = Tensor). it can also return a tuple, in which case data_kwargs needs to be set to the correctly ordered kwarg names
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)

trainer()

# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps
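
If your dataset returns a tuple rather than a dictionary, set data_kwargs to the kwarg names in the same order as the tuple. A minimal sketch, where the mock dataset and the chosen ordering are hypothetical:

import torch
from torch.utils.data import Dataset
from CALM_pytorch import FineTuner

class MockTupleDataset(Dataset):
    # hypothetical dataset yielding (seq, mask, prompt) tuples
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        seq = torch.randint(0, 20000, (1024,))
        mask = torch.ones((1024,)).bool()
        prompt = torch.randint(0, 20000, (256,))
        return seq, mask, prompt

trainer = FineTuner(
    calm = calm,
    dataset = MockTupleDataset(),
    data_kwargs = ('seq', 'mask', 'prompt'),   # names matched positionally to the tuple returned by the dataset
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)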

To compose with multiple augmentation LLMs, simply pass in a list of AugmentParams for augment_llms

ex.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams, each wrapping a model and the hparams specific to that transformer
)

Say you want to explore different types of connectivity between the anchor and the augmentation model(s). Just pass in the connections as a tuple of integer pairs, each pair specifying an augmentation layer followed by the anchor layer that attends to it.

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            connections = (
                (1, 12),  # 1st layer of augment llm1 attended to by 12th layer of anchor llm
                (2, 12),
                (3, 12),
                (4, 12),
            ),
        ),
        AugmentParams(
            model = augment_llm2,
            connections = (
                (6, 1),   # 6th layer of augment llm2 attended to by 1st layer of anchor llm
                (6, 2),
                (12, 12),
            )
        )
    )
)

CALM setup with 2 specialized augmentation LLMs + a vision transformer

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm1 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

augment_llm2 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = augment_llm2,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = vit,
            input_shape = (3, 256, 256),
            hidden_position = 'input',
            extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
        )
    ),
    attn_kwargs = dict(
        linear_project_context = True,
        pre_rmsnorm = True,
        flash = True
    )
)

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

prompt = (
    torch.randint(0, 20000, (1, 256)),
    torch.randint(0, 20000, (1, 256)),
    torch.randn(1, 3, 256, 256)
)

loss = calm(
    seq,
    mask = mask,
    prompt = prompt
)

loss.backward()

Todo

Citations

@inproceedings{Bansal2024LLMAL,
  title   = {LLM Augmented LLMs: Expanding Capabilities through Composition},
  author  = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
  year    = {2024},
  url     = {https://api.semanticscholar.org/CorpusID:266755751}
}