lucidrains / toolformer-pytorch

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
MIT License

Toolformer - Pytorch (wip)

Install

$ pip install toolformer-pytorch

Usage

Example usage: giving a language model awareness of the current date and time.

import torch
from toolformer_pytorch import Toolformer, PaLM

# simple calendar api call - function that returns a string

def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'

# prompt for teaching it to use the Calendar function from above

prompt = """
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output: 
"""

data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]

# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine

model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# toolformer

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

# invoking this will
# (1) prompt the model with your inputs (data), inserted into [input] tag
# (2) with the sampled outputs, filter out the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can be used independently as shown in the next section)
# (5) fine-tune on the filtered results

filtered_stats = toolformer(data)

# then, once you see the 'finetune complete' message

response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# hopefully you see it invoke the calendar and utilize the response of the api call...
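
Steps (1) and (2) in the comments above can be pictured with a small, library-independent sketch. The helper names `build_prompts` and `has_api_call` are hypothetical and are not part of `toolformer_pytorch`. The sketch only assumes that the `[input]` tag is filled by plain string substitution and that a proper call looks like `[Calendar()]`; the library's own handling may differ.

```python
import re

def build_prompts(teach_tool_prompt, inputs, placeholder='[input]'):
    # hypothetical helper: fill the [input] tag once per data sample
    return [teach_tool_prompt.replace(placeholder, text) for text in inputs]

def has_api_call(text, tool_id='Calendar'):
    # hypothetical helper: True if the text contains a call such as [Calendar()]
    pattern = r'\[' + re.escape(tool_id) + r'\(.*?\)\]'
    return re.search(pattern, text) is not None

# shortened stand-in for the teaching prompt
template = 'Input: [input]\nOutput: '
prompts = build_prompts(template, ['The current day of the week is Wednesday.'])

# pretend these strings were sampled from the model
sampled = [
    'The current day of the week is [Calendar()] Wednesday.',
    'The current day of the week is Wednesday.',
]

# keep only the samples that contain a well-formed call
kept = [s for s in sampled if has_api_call(s)]
```

Only the samples that survive this check would go on to have their API calls executed and scored.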

The main novelty of the paper is a fitness score for the outputs of a transformer that has been instructed to insert API calls. The score is used to filter the sampled outputs, so that the transformer is finetuned only on API calls that decrease the perplexity of the text that follows them.

import torch

from toolformer_pytorch import (
    Toolformer,
    PaLM,
    filter_tokens_with_api_response
)

# model

palm = PaLM(
    dim = 512,
    num_tokens = 20000,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# mock some tokens

mock_start_pos = 512
mock_api_call_length = 10
mock_api_start_id = 19998
mock_api_stop_id = 19999

tokens = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()

tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

# filter

filtered_results = filter_tokens_with_api_response(
    model = palm,
    tokens = tokens,
    tokens_with_api_response = tokens_with_api_response,
    tokens_without_api_response = tokens_without_api_response,
    filter_threshold = 1.,
    api_start_token_id = mock_api_start_id,
    api_end_token_id = mock_api_stop_id
)
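
The criterion behind this filter can be sketched in plain Python. This is a minimal sketch of the rule as described in the Toolformer paper, not the code inside `filter_tokens_with_api_response`. The function names, the linear weight decay of 0.2, and the use of per-token log-probabilities as inputs are assumptions made for illustration. An API call is kept when including the call and its response lowers the weighted loss on the following tokens by at least the threshold, compared with the better of two baselines: no call at all, or the call without its response.

```python
def weighted_nll(log_probs, decay=0.2):
    # weighted negative log-likelihood of the tokens after the call position
    # weights fall off linearly with distance and are normalized to sum to 1
    raw = [max(0.0, 1.0 - decay * t) for t in range(len(log_probs))]
    total = sum(raw)
    return -sum(w / total * lp for w, lp in zip(raw, log_probs))

def keep_api_call(lp_plain, lp_call_only, lp_call_and_response, threshold=1.0):
    # loss when the model sees the call together with its response
    loss_plus = weighted_nll(lp_call_and_response)
    # best loss achievable without the response
    loss_minus = min(weighted_nll(lp_plain), weighted_nll(lp_call_only))
    return (loss_minus - loss_plus) >= threshold

# made-up log-probabilities of the two tokens that follow the call
helpful = keep_api_call([-3.0, -3.0], [-3.0, -3.0], [-0.5, -0.5])
unhelpful = keep_api_call([-3.0, -3.0], [-3.0, -3.0], [-2.9, -2.9])
```

In the mock example above, `tokens`, `tokens_without_api_response` and `tokens_with_api_response` appear to play the roles of the three inputs here, and `filter_threshold` the role of `threshold`.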

To invoke the tools on a string generated by the language model, use invoke_tools:

from toolformer_pytorch import invoke_tools

def inc(i):
    return i + 1

def dec(i):
    return i - 1

function_registry = dict(
    inc = inc,
    dec = dec
)

text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'

invoke_tools(function_registry, text)

# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
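
For intuition, the same behavior can be reproduced with a short regular-expression sketch. This is not the library's implementation: `invoke_tools_sketch` is a hypothetical name, and the sketch assumes each call takes a single integer argument, which the real `invoke_tools` may not require.

```python
import re

# matches [name(arg)] and captures the tool name and the raw argument text
CALL_PATTERN = re.compile(r'\[(\w+)\((.*?)\)\]')

def invoke_tools_sketch(registry, text):
    def replace(match):
        name, raw_arg = match.group(1), match.group(2)
        fn = registry.get(name)
        if fn is None:
            # unknown tool: leave the call untouched
            return match.group(0)
        try:
            # assumption: a single integer argument
            result = fn(int(raw_arg))
        except Exception:
            # failed parse or failed call: leave the call untouched
            return match.group(0)
        return f'[{name}({raw_arg}) → {result}]'

    return CALL_PATTERN.sub(replace, text)

registry = dict(inc = lambda i: i + 1, dec = lambda i: i - 1)
text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'
output = invoke_tools_sketch(registry, text)
```

Calls to names missing from the registry, such as `ignored`, pass through unchanged, which matches the output shown above.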

Citations

@inproceedings{Schick2023ToolformerLM,
    title   = {Toolformer: Language Models Can Teach Themselves to Use Tools},
    author  = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
    year    = {2023}
}
@article{Gao2022PALPL,
    title   = {PAL: Program-aided Language Models},
    author  = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.10435}
}

Reality is that which, when you stop believing it, doesn't go away. – Philip K. Dick.