GenSE

Official implementation of the EMNLP 2022 paper "Generate, Discriminate, and Contrast: A Semi-Supervised Sentence Representation Learning Framework".

Overview

We propose a semi-supervised sentence embedding framework, GenSE, that effectively leverages large-scale unlabeled data. Our method includes three parts:

  • Generate: A generator/discriminator model is jointly trained to synthesize sentence pairs from an open-domain unlabeled corpus.
  • Discriminate: Noisy sentence pairs are filtered out by the discriminator to acquire high-quality positive and negative sentence pairs.
  • Contrast: A prompt-based contrastive approach is presented for sentence representation learning with both annotated and synthesized data.
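
The contrastive step can be illustrated with a SimCSE-style objective in which the entailment hypothesis acts as the positive and the contradiction hypothesis as a hard negative. This is a minimal sketch under that assumption (the project states it builds on SimCSE), not GenSE's exact loss; the function name `triplet_contrastive_loss` and the temperature of 0.05 are illustrative choices, and plain Python lists stand in for real embeddings.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def triplet_contrastive_loss(anchors, positives, negatives, temperature=0.05):
    """InfoNCE-style loss over a batch of (anchor, positive, negative) embeddings.

    For anchor i, the candidates are every positive and every negative in the
    batch, and the target is positive i. (Hypothetical helper, for illustration.)
    """
    candidates = positives + negatives
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [cosine(anchor, c) / temperature for c in candidates]
        # log-sum-exp with max subtraction for numerical stability
        top = max(logits)
        log_z = top + math.log(sum(math.exp(x - top) for x in logits))
        total += log_z - logits[i]
    return total / len(anchors)


# Toy usage: two anchors with matching positives and opposing negatives.
toy_anchors = [[1.0, 0.0], [0.0, 1.0]]
toy_positives = [[1.0, 0.0], [0.0, 1.0]]
toy_negatives = [[-1.0, 0.0], [0.0, -1.0]]
print(triplet_contrastive_loss(toy_anchors, toy_positives, toy_negatives))
```

The loss shrinks as each anchor moves closer to its own positive and further from the other candidates in the batch.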

Requirements

To run our code, please install all the dependency packages by using the following command:

pip install -r requirements.txt

Data Synthesis

We train a unified T5 model for data generation and discrimination. Details about data synthesis can be found in data_synthesis/README.md.

GenSE Training & Evaluation

After data synthesis, train the GenSE sentence embedding model by following gense_training/README.md.

Model List

All of our pre-trained models are available on the Hugging Face Hub:

Model
mattymchen/nli-synthesizer-t5-base
mattymchen/gense-base
mattymchen/gense-base-plus

Example Usage

Use with GenSE

We provide a simple package, which can be used to generate NLI triplets and compute sentence embeddings:

pip install gense

Generation & Discrimination

from gense import Synthesizer

synthesizer = Synthesizer('mattymchen/nli-synthesizer-t5-base')
input_sents = [
    'The task of judging the best was not easy.',
    'A man plays the piano.'
]

# generate NLI triplets
triplets = synthesizer.generate_triplets(input_sents)

# filter triplets
filtered_triplets = synthesizer.filter_triplets(triplets)

print(filtered_triplets)

Sentence Embedding

from gense import GenSE
gense = GenSE('mattymchen/gense-base-plus')
example_sentences = [
    'An animal is biting a persons finger.',
    'A woman is reading.',
    'A man is lifting weights in a garage.',
    'A man plays the violin.',
    'A man is eating food.',
    'A man plays the piano.',
    'A panda is climbing.',
    'A man plays a guitar.',
    'A woman is slicing a meat.',
    'A woman is taking a picture.'
]
example_queries = [
    'A man is playing music.',
    'A woman is making a photo.'
]

Encode sentence

print(gense.encode(example_sentences))

Compute cosine similarity

similarities = gense.similarity(example_queries, example_sentences)
print(similarities)
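
For intuition, the cosine similarity between queries and sentences can be written out by hand. The sketch below assumes the result is a queries-by-sentences matrix of scores; the README does not show the return shape of `gense.similarity`, so treat that as an assumption. The helper names are hypothetical and short toy vectors stand in for real embeddings.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cosine_similarity_matrix(query_vecs, sentence_vecs):
    """Return a len(query_vecs) x len(sentence_vecs) matrix of cosine scores."""
    queries = [l2_normalize(v) for v in query_vecs]
    sentences = [l2_normalize(v) for v in sentence_vecs]
    # after normalization, cosine similarity reduces to a dot product
    return [
        [sum(a * b for a, b in zip(q, s)) for s in sentences]
        for q in queries
    ]


# Toy usage with 2-dimensional stand-in embeddings.
toy_scores = cosine_similarity_matrix([[1.0, 0.0], [1.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]])
for row in toy_scores:
    print(row)
```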

Semantic search

gense.build_index(example_sentences, use_faiss=True)
results = gense.search(example_queries)
for i, result in enumerate(results):
    print("Retrieval results for query: {}".format(example_queries[i]))
    for sentence, score in result:
        print("    {}  (cosine similarity: {:.4f})".format(sentence, score))
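
Conceptually, semantic search ranks the indexed sentences by similarity to each query and keeps the best matches. The sketch below is a brute-force stand-in that takes a precomputed score matrix; it is not the package's Faiss-backed implementation, and `rank_results`, `top_k`, and `threshold` are hypothetical names. It returns one list of `(sentence, score)` pairs per query, the shape the loop above iterates over.

```python
def rank_results(scores, sentences, top_k=3, threshold=0.0):
    """Rank sentences for each query row of a similarity matrix.

    scores[i][j] is the similarity between query i and sentences[j].
    Keeps at most top_k matches per query whose score is >= threshold.
    """
    results = []
    for row in scores:
        ranked = sorted(zip(sentences, row), key=lambda pair: pair[1], reverse=True)
        results.append([(sent, score) for sent, score in ranked[:top_k] if score >= threshold])
    return results


# Toy usage: one query scored against three sentences.
toy_sentences = ['A man plays the piano.', 'A panda is climbing.', 'A man plays a guitar.']
for sentence, score in rank_results([[0.8, 0.1, 0.7]], toy_sentences, top_k=2)[0]:
    print("    {}  (cosine similarity: {:.4f})".format(sentence, score))
```

An approximate nearest-neighbor index such as Faiss avoids scoring every sentence for every query, which matters once the collection grows large.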

Use with Huggingface

Alternatively, you can also directly use GenSE with huggingface transformers.

Generation & Discrimination

from transformers import T5ForConditionalGeneration, AutoTokenizer

# load data synthesis model
synthesis_model = T5ForConditionalGeneration.from_pretrained('mattymchen/nli-synthesizer-t5-base')
synthesis_tokenizer = AutoTokenizer.from_pretrained('mattymchen/nli-synthesizer-t5-base')

# prepare inputs
input_sents = [
    'Write two sentences that are entailment. Sentence 1: \"The task of judging the best was not easy.\"Sentence 2:',
    'Write two sentences that are contradictory. Sentence 1: \"The task of judging the best was not easy.\"Sentence 2:',
    'if \"The task of judging the best was not easy.\", does this mean that \" It was difficult to judge the best.\"? true or false',
    'if \"The task of judging the best was not easy.\", does this mean that \" It was easy to judge the best.\"? true or false'
]
input_features = synthesis_tokenizer(input_sents, add_special_tokens=True, padding=True, return_tensors='pt')

# generation
# note: top_p only affects decoding when sampling is enabled (do_sample=True)
outputs = synthesis_model.generate(**input_features, top_p=0.9)

# Outputs:
# It was difficult to judge the best.
# It was easy to judge the best.
# true
# false
print(synthesis_tokenizer.batch_decode(outputs, skip_special_tokens=True))
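
The prompts above follow two fixed templates, so they can be built programmatically. The sketch below copies the template strings verbatim from the example inputs, including their spacing. The helper names and the filtering rule (keep a triplet only when the discriminator answers "true" for the entailment pair and "false" for the contradiction pair) are assumptions for illustration, not the package's actual filtering logic.

```python
# Templates copied from the example inputs above (spacing preserved as-is).
GENERATION_TEMPLATES = {
    'entailment': 'Write two sentences that are entailment. Sentence 1: "{premise}"Sentence 2:',
    'contradiction': 'Write two sentences that are contradictory. Sentence 1: "{premise}"Sentence 2:',
}
DISCRIMINATION_TEMPLATE = 'if "{premise}", does this mean that " {hypothesis}"? true or false'


def generation_prompt(premise, relation):
    """Build a generation prompt; relation is 'entailment' or 'contradiction'."""
    return GENERATION_TEMPLATES[relation].format(premise=premise)


def discrimination_prompt(premise, hypothesis):
    """Build a true/false discrimination prompt for a sentence pair."""
    return DISCRIMINATION_TEMPLATE.format(premise=premise, hypothesis=hypothesis)


def keep_triplet(premise, entailed, contradicted, discriminate):
    """Hypothetical filter: `discriminate` maps a prompt string to 'true' or 'false'."""
    return (
        discriminate(discrimination_prompt(premise, entailed)) == 'true'
        and discriminate(discrimination_prompt(premise, contradicted)) == 'false'
    )


premise = 'The task of judging the best was not easy.'
print(generation_prompt(premise, 'entailment'))
print(discrimination_prompt(premise, 'It was difficult to judge the best.'))
```

In practice `discriminate` would wrap a call to the synthesis model and decode its answer; any callable with the same contract can be swapped in for testing.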

Sentence Embedding

import torch
from transformers import T5Model, AutoTokenizer

# load embedding model
embedding_model = T5Model.from_pretrained('mattymchen/gense-base-plus').eval()
embedding_tokenizer = AutoTokenizer.from_pretrained('mattymchen/gense-base-plus')

# prepare inputs
input_sents = [
    'The task of judging the best was not easy. Question: what can we draw from the above sentence?',
]
input_features = embedding_tokenizer(input_sents, add_special_tokens=True, padding=True, return_tensors='pt')
# read the decoder start token from the model config rather than a private method
decoder_start_token_id = embedding_model.config.decoder_start_token_id
input_features['decoder_input_ids'] = torch.full([input_features['input_ids'].shape[0], 1], decoder_start_token_id)

# inference
with torch.no_grad():
    outputs = embedding_model(**input_features, output_hidden_states=True, return_dict=True)
    last_hidden = outputs.last_hidden_state
    sent_embs = last_hidden[:, 0].cpu()
print(sent_embs)
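
When embedding many sentences this way, each one needs the same question prompt appended before tokenization. The helper below is a small sketch of that preprocessing: the prompt text is copied from the example input above, while `batched_prompts` and its `batch_size` parameter are hypothetical names introduced here.

```python
# Prompt suffix copied from the example input above.
EMBEDDING_PROMPT = 'Question: what can we draw from the above sentence?'


def batched_prompts(sentences, batch_size=32):
    """Yield lists of prompt-augmented sentences, at most batch_size per list."""
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        yield ['{} {}'.format(sentence, EMBEDDING_PROMPT) for sentence in batch]


# Each yielded batch can be passed to the tokenizer in place of input_sents.
for batch in batched_prompts(['A man plays the piano.', 'A panda is climbing.'], batch_size=1):
    print(batch)
```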

Synthetic Data

We run our unified data synthesis model on open-domain unlabeled sentences to obtain synthetic NLI triplets for GenSE training.

The resulting synthetic dataset SyNLI contains around 61M NLI triplets, which can now be downloaded from huggingface hub:

from datasets import load_dataset

dataset = load_dataset("mattymchen/synli")

Citation

Please cite our paper if you use GenSE in your work:

@inproceedings{chen2022gense,
  title={Generate, Discriminate and Contrast: A Semi-Supervised Sentence Representation Learning Framework},
  author={Chen, Yiming and Zhang, Yan and Wang, Bin and Liu, Zuozhu and Li, Haizhou},
  booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
  year={2022}
}

Acknowledgement

Code is implemented based on SimCSE. We would like to thank the authors of SimCSE for making their code public.

gense's Issues

Release Synthetic Corpus

Hi Yiming, great work on synthetic corpus generation for the STS task! I have a simple request: would you be able to release the generated synthetic corpus? Many thanks for your contribution!

Why not train with BERT or PromptBERT instead of T5?

Hello! Congratulations on this great work!

I want to ask why you don't just use BERT or PromptBERT directly after deriving the augmented data.
Does T5 or another generation model work better than BERT, which is an NLU model?

I ask because I have used T5 and BART directly for SimCSE training, and the results were terrible.

How to train generator and discriminator

Dear author,

I've been reading your paper and I have a question regarding the training process for the generator and discriminator.
I understand that the C4 news-like and English partitions are used to obtain synthetic triplets,
but I'm still confused about how the generator and discriminator models are trained without labeled data.
(I feel that input-output pairs are needed to train the generator and discriminator, but the data above consist only of sampled sentences.)

Could you please clarify this for me?

Thank you
