
multitask-prompting

Overview

Prompting uses the ability of large language models (LLMs) to "fill in the blank" in order to classify the meaning of text. I researched how a single model can use prompting to learn multiple tasks simultaneously. RoBERTa was the primary model I experimented with. Here are the results. This repository contains the code I used to train and evaluate models during my experiments.
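To make the "fill in the blank" idea concrete, here is a minimal sketch of one common way to turn a masked language model's prediction into a class label: read the scores at the <mask> position, keep only the words that stand for labels, and pick the highest. The vocabulary, scores, and the `predict_label` helper are made-up illustrations, not part of this repository; RobertaPrompt's policy functions may instead map whichever token the model predicts.

```python
import math

def predict_label(mask_logits, label_words, vocab):
    """Return the label whose word scores highest at the <mask> position.

    mask_logits: one score per vocabulary entry at the <mask> position.
    label_words: maps each label to the single word that stands for it.
    vocab: maps each word to its vocabulary index.
    """
    # Keep only the scores of the candidate label words.
    scores = {label: mask_logits[vocab[word]] for label, word in label_words.items()}
    # Softmax over the candidates (subtracting the max for numerical stability).
    peak = max(scores.values())
    exps = {label: math.exp(s - peak) for label, s in scores.items()}
    total = sum(exps.values())
    probs = {label: e / total for label, e in exps.items()}
    return max(probs, key=probs.get), probs

# Toy vocabulary and scores, made up for illustration.
vocab = {"the": 0, "support": 1, "against": 2, "topic": 3}
mask_logits = [4.0, 2.5, 1.0, 0.3]
label, probs = predict_label(mask_logits, {"pro": "support", "con": "against"}, vocab)
print(label)  # pro
```

Restricting the comparison to candidate words matters here: over the full toy vocabulary the top token is "the", which carries no label at all.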

The paper How Many Data Points is a Prompt Worth? contains a diagram that explains the difference between prompting and head-based fine-tuning for language models, and it runs some cool experiments showing how much prompts help in low-resource settings.

RobertaPrompt

RobertaPrompt wraps HuggingFace's RobertaForMaskedLM class and lets developers train and test a RoBERTa model with prompting, based on a prompt definition.

Suppose we have two tasks: given an argument and a topic, we must detect whether the argument supports or opposes the topic, and also whether the argument contains a fallacy (and, if it does, which fallacy).

A prompt definition would contain:

  1. A template for each task. A template is a consistent text pattern associated with a task, so the model recognizes which task it needs to complete. For example:
"Stance detection task. Topic: {insert topic here} and Argument: {insert argument here}. The stance is: <mask>"
  2. A policy function for each task. The policy function maps the token the model uses to fill in the blank (<mask>) to the predicted label.
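A prompt definition for the stance task could look roughly like the sketch below. The template text is the one quoted in the list; the function names, the `STANCE_TOKENS` table, and the label strings are hypothetical, so check `argument_prompt` in the repository for the structure RobertaPrompt actually expects. A second task such as fallacy detection would get its own template and policy function in the same way.

```python
# Hypothetical stance-task prompt definition; names and labels are illustrative.
STANCE_TEMPLATE = (
    "Stance detection task. Topic: {topic} and Argument: {argument}. "
    "The stance is: <mask>"
)

def build_stance_prompt(topic: str, argument: str) -> str:
    """Insert one example into the stance template."""
    return STANCE_TEMPLATE.format(topic=topic, argument=argument)

# Hypothetical table from predicted filler tokens to stance labels.
STANCE_TOKENS = {"support": "in favor", "against": "against"}

def stance_policy(predicted_token: str) -> str:
    """Map the token predicted for <mask> to a stance label."""
    # RoBERTa's byte-level BPE marks a leading space with "Ġ"; drop it, then normalize.
    token = predicted_token.lstrip("Ġ").strip().lower()
    # Tokens outside the table get a catch-all label instead of raising an error.
    return STANCE_TOKENS.get(token, "unknown")

prompt = build_stance_prompt("School uniforms", "Uniforms reduce bullying")
print(prompt)
print(stance_policy("Ġsupport"))  # in favor
```

Returning a catch-all label for unmapped tokens keeps evaluation running when the model fills the blank with an unrelated word; such predictions then simply count as errors.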

My experiments trained a RoBERTa model on exactly the tasks described above; you can see some example predictions in the prompting_example.ipynb notebook.

Training and Testing

First, load a base model. Using a GPU as the device is highly recommended.

import torch
pmodel = RobertaPrompt(model='roberta-large', device=torch.device('cuda'), prompt=argument_prompt)

Start training immediately by specifying the paths to a training and validation dataset. Training statistics will be displayed in stdout.

pmodel.train("sample_train_set.tsv", "sample_val_set.tsv", output_dir="sample_model", epochs=10)
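Training with prompts typically reuses the masked-language-modeling objective: the loss is the cross-entropy of the word that encodes the gold label, measured at the <mask> position. The sketch below computes that loss for a single example from raw scores. It is a common formulation, not necessarily the exact loss inside `pmodel.train`, and `masked_label_loss` is a hypothetical helper.

```python
import math

def masked_label_loss(mask_logits, target_index):
    """Cross-entropy loss for one example.

    mask_logits: scores over the vocabulary at the <mask> position.
    target_index: vocabulary index of the word that encodes the gold label.
    """
    # Numerically stable log-sum-exp: subtract the max before exponentiating.
    peak = max(mask_logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in mask_logits))
    # Equals -log(softmax(mask_logits)[target_index]).
    return log_sum_exp - mask_logits[target_index]

# With uniform scores over 4 words, the loss is log(4), about 1.386.
print(masked_label_loss([0.0, 0.0, 0.0, 0.0], target_index=1))
# Raising the gold word's score lowers the loss.
print(masked_label_loss([0.0, 3.0, 0.0, 0.0], target_index=1))
```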

After training is finished, evaluate the model on a test set and save the test results with the following function:

pmodel.test("sample_test_set.tsv", save_path='stats.txt')

The file specified by save_path will contain the test results. Overall F1 scores are included, along with finer-grained statistics on model performance for each label.
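Per-label statistics like these are usually precision, recall, and F1. The sketch below computes them from lists of gold and predicted labels, using a macro average (the unweighted mean of per-label F1) for the overall score. The macro average is an assumption: RobertaPrompt may use micro or weighted averaging instead. The function name and example labels are hypothetical.

```python
def label_f1_report(gold, pred):
    """Per-label precision/recall/F1 plus a macro-averaged F1."""
    labels = sorted(set(gold) | set(pred))
    report = {}
    for label in labels:
        tp = sum(1 for g, p in zip(gold, pred) if g == label and p == label)
        fp = sum(1 for g, p in zip(gold, pred) if g != label and p == label)
        fn = sum(1 for g, p in zip(gold, pred) if g == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {"precision": precision, "recall": recall, "f1": f1, "support": tp + fn}
    if not report:
        return report, 0.0
    # Macro average: every label counts equally, regardless of how frequent it is.
    macro_f1 = sum(r["f1"] for r in report.values()) / len(report)
    return report, macro_f1

gold = ["pro", "pro", "con", "con"]
pred = ["pro", "con", "con", "con"]
report, macro_f1 = label_f1_report(gold, pred)
for label, r in report.items():
    print(label, round(r["f1"], 3))
print("macro F1", round(macro_f1, 3))
```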


You can then take this model and fine-tune it on other tasks with different prompts.

Data

Sample data for fallacious argument and stance detection is from Argotario.
