toolformer-pytorch's Introduction

Toolformer - Pytorch (wip)

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

  • Enrico for getting the ball rolling with the initial commit of different tools!

  • Thanks goes out to ChatGPT for doing all the regular expressions in this repository for parsing the functions and parameters for the API calls. I am terrible at regular expressions, so this was an enormous help from the AI (with no hitches, it was perfect).
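
As an illustration only (not necessarily the exact pattern used in this repository), a regex that pulls a function name and its raw argument string out of a bracketed API call might look like this:

import re

# hypothetical pattern: matches calls like [Calendar()] or [Calculator(1 + 2)]
# and captures the function name and the raw argument string
API_CALL_PATTERN = re.compile(r'\[(\w+)\(([^)]*)\)\]')

text = 'It will take her [Calculator(120 / 24)] 5 hours to read 120 pages.'

for name, args in API_CALL_PATTERN.findall(text):
    print(name, '->', args)   # prints: Calculator -> 120 / 24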

Install

$ pip install toolformer-pytorch

Usage

Example usage: giving a language model awareness of the current date and time.

import torch
from toolformer_pytorch import Toolformer, PaLM

# simple calendar api call - function that returns a string

def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'

# prompt for teaching it to use the Calendar function from above

prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output: 
"""

data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]

# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine

model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# toolformer

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

# invoking this will
# (1) prompt the model with your inputs (data), inserted into [input] tag
# (2) with the sampled outputs, filter out the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can be used independently as shown in the next section)
# (5) fine-tune on the filtered results

filtered_stats = toolformer(data)

# then, once you see the 'finetune complete' message

response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# hopefully you see it invoke the calendar and utilize the response of the api call...

The main novelty of the paper is defining a fitness score for the outputs of a transformer instructed to insert API calls. The score is used to filter the sampled outputs, so that the transformer is only fine-tuned on API calls that decrease the perplexity of the text that follows them.
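
Roughly, and only as a sketch (the tensor names and threshold below are illustrative, not the repository's exact API), the criterion compares the loss of the continuation when the API call and its response are provided against the better of two baselines (no API call at all, or the call without its response), and keeps a sample when the improvement is at least the threshold:

import torch

# hypothetical per-example losses, e.g. weighted cross-entropy over the tokens
# that follow the position of the API call
loss_with_api_and_response = torch.tensor([2.1, 3.0, 2.5])
loss_without_api           = torch.tensor([2.9, 3.1, 2.4])
loss_with_api_no_response  = torch.tensor([3.0, 3.2, 2.6])

filter_threshold = 0.5  # illustrative value

# baseline: the better (lower) of the two losses that do not use the API response
baseline = torch.minimum(loss_without_api, loss_with_api_no_response)

# keep the samples whose API call reduces the loss by at least the threshold
keep = (baseline - loss_with_api_and_response) >= filter_threshold
print(keep)  # tensor([ True, False, False])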

import torch

from toolformer_pytorch import (
    Toolformer,
    PaLM,
    filter_tokens_with_api_response
)

# model

palm = PaLM(
    dim = 512,
    num_tokens = 20000,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# mock some tokens

mock_start_pos = 512
mock_api_call_length = 10
mock_api_start_id = 19998
mock_api_stop_id = 19999

tokens = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()

tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

# filter

filtered_results = filter_tokens_with_api_response(
    model = palm,
    tokens = tokens,
    tokens_with_api_response = tokens_with_api_response,
    tokens_without_api_response = tokens_without_api_response,
    filter_threshold = 1.,
    api_start_token_id = mock_api_start_id,
    api_end_token_id = mock_api_stop_id
)

To invoke the tools on a string generated by the language model, use `invoke_tools`.

from toolformer_pytorch import invoke_tools

def inc(i):
    return i + 1

def dec(i):
    return i - 1

function_registry = dict(
    inc = inc,
    dec = dec
)

text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'

invoke_tools(function_registry, text)

# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]

Todo

  • create custom generate function for palm that can do external API calls
    • allow for generating tokens at different cursor indices
    • api token (which was left and right brackets in the paper) needs to be customizable
    • allow for customizing how to handle errors in the function name, parameters, or execution and output
  • Toolformer should eventually calculate all statistics (how many were properly sampled, how many were filtered out by the different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning
  • do end-to-end training in Toolformer
    • doing the prompting and bootstrapping the data
    • prefiltering of bootstrapped data followed by api calls and then another round of filtering
      • keep track of all stats
    • take care of fine-tuning
      • interleaving of datasets + optimizer hyperparams
  • hook up gpt-j
  • test for a simple calculator eval dataset
  • add a default callback within the Toolformer that automatically aligns the text and checks for validity before the filtering step - if the text was not copied correctly, the filtering step is not valid.
  • make sure the final model, trained on many Toolformer instances, can be invoked with multiple tools - start with a batch size of 1 and work the way up

Citations

@inproceedings{Schick2023ToolformerLM,
    title   = {Toolformer: Language Models Can Teach Themselves to Use Tools},
    author  = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
    year    = {2023}
}
@article{Gao2022PALPL,
    title   = {PAL: Program-aided Language Models},
    author  = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.10435}
}

Reality is that which, when you stop believing it, doesn't go away. – Philip K. Dick.

toolformer-pytorch's People

Contributors

conceptofmind, lucidrains, murthyn


toolformer-pytorch's Issues

TypeError: 'type' object is not subscriptable

Description

After installing toolformer with "$ pip install toolformer-pytorch", I tried to run the example.
When I call "from toolformer_pytorch import Toolformer, PaLM", I get a type error.

Screenshots

import torch
from toolformer_pytorch import Toolformer, PaLM
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/memect/ck/workspace/toolformer-pytorch/toolformer_pytorch/__init__.py", line 3, in <module>
    from toolformer_pytorch.toolformer_pytorch import (
  File "/home/memect/ck/workspace/toolformer-pytorch/toolformer_pytorch/toolformer_pytorch.py", line 110, in <module>
    registry: dict[str, Callable],
TypeError: 'type' object is not subscriptable
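
This error typically indicates a Python version older than 3.9: subscripting the built-in dict type, as in dict[str, Callable], was only added in Python 3.9 (PEP 585), so on 3.8 and earlier the annotation raises exactly this TypeError at import time. Upgrading Python avoids it; for reference, the equivalent annotation that also works on older versions uses typing.Dict, as in this generic sketch (not the repository's code):

from typing import Callable, Dict

# dict[str, Callable] raises "TypeError: 'type' object is not subscriptable"
# on Python 3.8 and earlier; typing.Dict works on those versions as well
registry: Dict[str, Callable] = {}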

Loss Weights Calculation Error: toolformer_pytorch.py#L427

When I needed the original three loss scores, I returned and printed them, and I found that the values were all very large, which did not look normal. So I checked how the loss is calculated and found what seems to be an error in the weight computation.

In your code at toolformer_pytorch.py line 427, you use 0 to replace the elements of weights where weights == pad_id (-1), but when the elements equal to -1 in t (obtained from get_arange_start_at_token_id()) pass through weighting_fn(), they all become 1.2, so no element of weights is equal to -1 anymore.
So when we calculate the weighted loss, the positions we don't want to count (e.g., before ) get weight 1.2, which makes the final loss value very large.

Just changing the condition at toolformer_pytorch.py line 427 to (t == pad_id) or weights == 1.2 fixes this.
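
For illustration only, a minimal sketch of the masking behaviour described above (the weighting function and values here are guesses that reproduce the 1.2 the issue mentions, not the repository's actual code):

import torch

pad_id = -1

# t marks, for each position, its offset relative to the API stop token;
# pad_id marks positions that should not contribute to the loss (made-up values)
t = torch.tensor([pad_id, pad_id, 0, 1, 2])

# hypothetical weighting function 1 - 0.2 * t, floored at 0; note that the pad
# positions also pass through it and come out as 1.2
weights = (1.0 - 0.2 * t.float()).clamp(min = 0.)

# masking where weights == pad_id catches nothing, since the pads are now 1.2
buggy_weights = weights.masked_fill(weights == pad_id, 0.)

# masking on t, where the pad marker still survives, zeroes the unwanted positions
fixed_weights = weights.masked_fill(t == pad_id, 0.)

print(buggy_weights)  # tensor([1.2000, 1.2000, 1.0000, 0.8000, 0.6000])
print(fixed_weights)  # tensor([0.0000, 0.0000, 1.0000, 0.8000, 0.6000])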

issues about invoke_tools

My invoke_tools function can't get the right outputs, but when I replace the "regex = create_function_regex(api_start, api_stop)" line in the function with regex = create_function_regex(), things get better.
I don't know why but it does work.

About the Filtering API Calls implementation

Hello, in the original paper the authors write "We provide e(ci, ri) as a prefix instead of inserting it at position i because M is not yet finetuned on any examples containing API calls, so inserting it in the middle of x would interrupt the flow and not align with patterns in the pretraining corpus, thus hurting perplexity." in the footnote on page 3.

However, I found that you seem to insert the API call into the original sentence when calculating the loss.

What may I be missing?
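
As a toy illustration of the two variants the footnote distinguishes (the text and API call below are made up for the example):

x = "Joe Biden was born in Scranton, Pennsylvania."
api = "[WikiSearch(Joe Biden) -> Joe Biden was born in Scranton.]"
i = x.index("born")  # hypothetical position where the API call would be inserted

# what the paper's footnote describes: score x with e(c, r) prepended as a prefix
scored_as_prefix = api + " " + x

# in-place insertion at position i, which the issue suggests the code is doing
scored_with_insertion = x[:i] + api + " " + x[i:]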

Misplaced API Calls, which can pass the filtering step

When I wanted to implement Toolformer, I ran into this problem.

Consider this input:
In one hour, there are 3 sets of 20 minutes.
So, Joy can read 8 x 3 = 24 pages in an hour.
It will take her 120/24 = 5 hours to read 120 pages.

With the API generation steps, it eventually becomes:
In one hour, there are 3 sets of 20 minutes.
So, Joy can read 8 x 3 = 24 pages in an hour.
It will take her [CALCULATOR(24 * 5) -> 120.00] 120/24 = 5 hours to read 120 pages.

This result is not expected, since the API call CALCULATOR(24 * 5) includes the number 5, which is only mentioned after the API call.
I suppose this API call is misplaced. However, it cannot be removed by the filtering step, since the call involves 24, 5 and 120, which all appear later in the original text, and hence it does decrease the perplexity.

How can this problem be solved?

Error when running readme example

(The code is the README usage example from above, unchanged except that the model is moved to the Apple mps device with .to('mps') instead of .cuda().)

Error when running the above code

Traceback (most recent call last):
  File "/Users/nripeshniketan/Documents - Nripesh’s MacBook Pro/python_programs/toolformer-ivy/toolformer.py", line 60, in <module>
    filtered_stats = toolformer(data)
                     ^^^^^^^^^^^^^^^^
  File "/Users/nripeshniketan/Documents - Nripesh’s MacBook Pro/python_programs/toolformer-ivy/toolformer_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nripeshniketan/Documents - Nripesh’s MacBook Pro/python_programs/toolformer-ivy/toolformer_dev/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<@beartype(toolformer_pytorch.toolformer_pytorch.Toolformer.forward) at 0x160673ec0>", line 41, in forward
  File "/Users/nripeshniketan/Documents - Nripesh’s MacBook Pro/python_programs/toolformer-ivy/toolformer_dev/lib/python3.11/site-packages/toolformer_pytorch/toolformer_pytorch.py", line 883, in forward
    assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering
