lucidrains / CALM-pytorch

Implementation of CALM from the paper "LLM Augmented LLMs: Expanding Capabilities through Composition", out of Google Deepmind
MIT License
170 stars 9 forks source link
artificial-intelligence attention-mechanisms cross-attention deep-learning transformers

CALM - Pytorch

Implementation of CALM from the paper LLM Augmented LLMs: Expanding Capabilities through Composition, out of Google Deepmind

Can support any number of augmentation LLMs


$ pip install CALM-pytorch



ex. with x-transformers

import torch
from x_transformers import TransformerWrapper, Decoder

# augmentation model: a 12-layer decoder over a 20000-token vocabulary
augment_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# anchor model: a 2-layer decoder sharing the vocabulary size and width above
anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# import CALM wrapper

from CALM_pytorch import CALM, AugmentParams

# compose the anchor model with a single augmentation model
calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)

# mock input

# random token ids for a full-length sequence, an all-True mask of the same
# shape, and a shorter random prompt
seq = torch.randint(low = 0, high = 20000, size = (1, 1024))
mask = torch.ones(1, 1024, dtype = torch.bool)
prompt = torch.randint(low = 0, high = 20000, size = (1, 256))

# forward for finetuning loss

# the composed model takes the sequence, its mask and the prompt
loss = calm(
    seq = seq,
    mask = mask,
    prompt = prompt
)


# after much training, prompt the composed model

# prompt with the first token of `seq` and request a sequence length of 1024
generated = calm.generate(
    prompt = seq[:, :1],
    seq_len = 1024
)

To use a handy trainer class using 🤗 Accelerate, just import FineTuner and use as follows

from CALM_pytorch import FineTuner

# `dataset` must be supplied by the caller; see the note on the `dataset` argument
trainer = FineTuner(
    calm = calm,
    dataset = dataset,   # returns a dictionary of input kwargs to calm - dict(seq: Tensor, mask: Tensor, prompt: Tensor). it can also return a Tuple, in which data_kwargs needs to be set to the correct ordered value of kwarg names
    batch_size = 16,
    num_train_steps = 10000,
    learning_rate = 3e-4,
    weight_decay = 1e-2,
    warmup_steps = 1000,
    checkpoint_every = 1000
)


# checkpoints of the cross attention parameters will be saved to ./checkpoints every 1000 steps

To explore multiple augmentation LLMs, simply pass in a list for augment_llms


calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = [AugmentParams(augment_llm1), AugmentParams(augment_llm2)] # pass in a list of AugmentParams wrapping model and other hparams specific to that transformer
)

To explore different types of connectivity between the anchor and augmentation model(s), pass in the connections as a tuple of integer pairs. Each pair is (augment layer, anchor layer): the given layer of the augmentation model is attended to by the given layer of the anchor model.

# each AugmentParams pairs one augmentation model with its own explicit
# (augment layer, anchor layer) connections
calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            connections = (
                (1, 12),  # 1st layer of augment llm1 attended to by 12th layer of anchor llm
                (2, 12),
                (3, 12),
                (4, 12),
            ),
        ),
        AugmentParams(
            model = augment_llm2,
            connections = (
                (6, 1),   # 6th layer of augment llm2 attended to by 1st layer of anchor llm
                (6, 2),
                (12, 12),
            ),
        ),
    )
)

CALM setup with 2 specialized augmentation LLMs + a vision transformer

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

# anchor model: a small 12-layer decoder
anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

# first augmentation model: a 12-layer encoder
augment_llm1 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

# second augmentation model: same configuration as the first
augment_llm2 = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 16,
        dim_head = 2,
        depth = 12,
        heads = 8
    )
)

# vision transformer used as the third augmentation model
vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 256,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# calm

from CALM_pytorch import CALM, AugmentParams, FineTuner

# one AugmentParams per augmentation model: the two text encoders name their
# mask keyword, while the ViT entry supplies its input shape and a function
# that collects its Attention modules
calm = CALM(
    anchor_llm = anchor_llm,
    augment_llms = (
        AugmentParams(
            model = augment_llm1,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = augment_llm2,
            mask_kwarg = 'mask'
        ),
        AugmentParams(
            model = vit,
            input_shape = (3, 256, 256),
            hidden_position = 'input',
            extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
        ),
    ),
    attn_kwargs = dict(
        linear_project_context = True,
        pre_rmsnorm = True,
        flash = True
    )
)

# random token ids for a full-length sequence and an all-True mask of the same shape
seq = torch.randint(low = 0, high = 20000, size = (1, 1024))
mask = torch.ones(1, 1024, dtype = torch.bool)

# one prompt per augmentation model, in the same order as `augment_llms`:
# two token-id prompts followed by one image batch
prompt = (
    torch.randint(0, 20000, (1, 256)),
    torch.randint(0, 20000, (1, 256)),
    torch.randn(1, 3, 256, 256)
)

# the composed model takes the sequence, its mask and the tuple of prompts
loss = calm(
    seq = seq,
    mask = mask,
    prompt = prompt
)




@inproceedings{Bansal2024LLMAL,
  title   = {LLM Augmented LLMs: Expanding Capabilities through Composition},
  author  = {Rachit Bansal and Bidisha Samanta and Siddharth Dalmia and Nitish Gupta and Shikhar Vashishth and Sriram Ganapathy and Abhishek Bapna and Prateek Jain and Partha Pratim Talukdar},
  year    = {2024},
  url     = {}
}