Learning To Retrieve (LTRe)

Recently, we proposed a new dense retrieval training method, Learning To Retrieve (LTRe). It is both effective and efficient: it achieves an MRR@10 of 0.341 on the MS MARCO Passage dev set for first-stage retrieval, at a latency of 47 ms. Check out our paper: Learning To Retrieve: How to Train a Dense Retrieval Model Effectively and Efficiently.

RepBERT

RepBERT is currently the state-of-the-art first-stage retrieval technique on the MS MARCO Passage Ranking task. It represents documents and queries with fixed-length contextualized embeddings, and the inner product of a query embedding and a document embedding serves as their relevance score. Its efficiency is comparable to that of bag-of-words methods. For more details, check out our paper.

| MS MARCO Passage Ranking Leaderboard (Jun 28th 2020) | Category | Eval MRR@10 | Latency |
|---|---|---|---|
| BM25 + BERT from (Nogueira and Cho, 2019) | Cascade | 0.358 | 3400 ms |
| LTRe (our recent work) | First-Stage | 0.341 (Dev) | 47 ms |
| RepBERT (this code) | First-Stage | 0.294 | 80 ms |
| BiLSTM + Co-Attention + self attention based document scorer (Alaparthi et al., 2019) (best non-BERT) | Cascade | 0.291 | - |
| docTTTTTquery (Nogueira et al., 2019) | First-Stage | 0.272 | 64 ms |
| DeepCT (Dai and Callan, 2019) | First-Stage | 0.239 | 55 ms |
| doc2query (Nogueira et al., 2019) | First-Stage | 0.218 | 90 ms |
| BM25 (Anserini) | First-Stage | 0.186 | 50 ms |

Data and Trained Models

We make the following data available for download:

  • repbert.dev.small.top1k.tsv: 6,980,000 pairs of dev set queries and retrieved passages. In this tsv file, the first column is the query id, the second column is the passage id, and the third column is the rank of the passage. There are 1000 passages per query in this file.
  • repbert.eval.small.top1k.tsv: 6,837,000 pairs of eval set queries and retrieved passages. In this tsv file, the first column is the query id, the second column is the passage id, and the third column is the rank of the passage. There are 1000 passages per query in this file.
  • repbert.ckpt-350000.zip: Trained BERT base model to represent queries and passages. It contains two files, namely config.json and pytorch_model.bin.
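To illustrate the three-column layout of the two run files described above, here is a minimal sketch that loads such a file into a dictionary mapping each query id to its passage ids in rank order. The function name load_run is hypothetical and not part of this repository, and the sketch assumes that ids and ranks are integers.

```python
import csv
from collections import defaultdict


def load_run(path):
    """Read a run file with columns: query id, passage id, rank (tab-separated).

    Returns a dict mapping query id -> list of passage ids sorted by rank.
    """
    pairs_by_query = defaultdict(list)
    with open(path, newline="", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            if not row:
                continue  # skip blank lines
            qid, pid, rank = row
            pairs_by_query[int(qid)].append((int(rank), int(pid)))
    # Sort by rank so the result does not depend on the line order in the file.
    return {
        qid: [pid for _, pid in sorted(pairs)]
        for qid, pairs in pairs_by_query.items()
    }
```

With the provided files, each query id should map to a list of 1000 passage ids.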

Download the files from the table below and verify them against their MD5 checksums:

| File | Size | MD5 | Download |
|---|---|---|---|
| repbert.dev.small.top1k.tsv | 127 MB | 0d08617b62a777c3c8b2d42ca5e89a8e | [Google Drive] |
| repbert.eval.small.top1k.tsv | 125 MB | b56a79138f215292d674f58c694d5206 | [Google Drive] |
| repbert.ckpt-350000.zip | 386 MB | b59a574f53c92de6a4ddd4b3fbef784a | [Google Drive] |
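One way to check a download against the MD5 values listed above is a small script using Python's standard hashlib module. This is a sketch, not a tool shipped with this repository; the helper name md5_of_file is an assumption, and the file is read in chunks so that large archives do not need to fit in memory.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at `path`."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read 1 MiB at a time until read() returns an empty bytes object.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, md5_of_file("repbert.ckpt-350000.zip") should equal the checksum listed for that file if the download is intact.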

Replicating Results with Provided Trained Model

We provide instructions on how to replicate the RepBERT retrieval results using the provided trained model.

First, install 🤗 Transformers and clone this repository:

pip install transformers
git clone https://github.com/jingtaozhan/RepBERT-Index
cd RepBERT-Index

Next, download collectionandqueries.tar.gz from MSMARCO-Passage-Ranking. It contains passages, queries, and qrels.

mkdir data
cd data
wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
mkdir msmarco-passage
tar xvfz collectionandqueries.tar.gz -C msmarco-passage

To confirm, collectionandqueries.tar.gz should have MD5 checksum of 31644046b18952c1386cd4564ba2ae69.

To avoid repeating work during training and testing, we tokenize queries and passages in advance. This takes about 3-4 hours. We also convert the tokenized passages to a numpy memmap array, which greatly reduces memory overhead at run time.

python convert_text_to_tokenized.py --tokenize_queries --tokenize_collection
python convert_collection_to_memmap.py
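The memory saving comes from reading individual passages straight from disk instead of holding the whole tokenized collection in RAM. The sketch below illustrates that idea with Python's standard mmap module as a stand-in for numpy.memmap. The fixed row width MAX_LEN, the file layout, and both function names are assumptions made for illustration; they do not mirror what convert_collection_to_memmap.py actually writes.

```python
import mmap
from array import array

MAX_LEN = 8  # hypothetical fixed row width (token ids per passage)


def write_token_rows(path, rows, max_len=MAX_LEN):
    """Write each token-id list as one fixed-width row of C ints, zero-padded."""
    with open(path, "wb") as f:
        for ids in rows:
            padded = (list(ids) + [0] * max_len)[:max_len]
            array("i", padded).tofile(f)


def read_token_row(path, index, max_len=MAX_LEN):
    """Fetch one row through a memory map without loading the whole file."""
    row_bytes = max_len * array("i").itemsize
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        start = index * row_bytes
        row = array("i")
        # Only the bytes of the requested row are copied out of the map.
        row.frombytes(mm[start:start + row_bytes])
        return list(row)
```

Because every row has the same width, the byte offset of passage i is simply i times the row size, so any passage can be looked up directly.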

Please download the provided model repbert.ckpt-350000.zip, put it in ./data, and unzip it. You should see two files in the directory ./data/ckpt-350000, namely pytorch_model.bin and config.json.

Next, you need to precompute the representations of passages and queries.

python precompute.py --load_model_path ./data/ckpt-350000 --task doc
python precompute.py --load_model_path ./data/ckpt-350000 --task query_dev.small
python precompute.py --load_model_path ./data/ckpt-350000 --task query_eval.small

Finally, you can retrieve the passages for the queries in the dev set (or eval set). multi_retrieve.py uses the GPUs specified by the --gpus argument and distributes the representations of all passages evenly among them. If your CUDA memory is limited, you can use --per_gpu_doc_num to specify the number of passages placed on each GPU.

python multi_retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --gpus 0,1,2,3,4
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv
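The idea of splitting the passage representations into shards, scoring each shard by inner product, and merging the per-shard top hits can be sketched in plain Python as follows. This is a CPU-only illustration under stated assumptions, not the code in multi_retrieve.py: the function names and the per_shard_doc_num parameter are hypothetical, and each shard stands in for the passages held by one GPU.

```python
import heapq


def inner_product(a, b):
    """Relevance score of a query embedding and a passage embedding."""
    return sum(x * y for x, y in zip(a, b))


def search_shard(query, shard, hit):
    """Return the `hit` best (score, passage id) pairs within one shard."""
    scored = ((inner_product(query, emb), pid) for pid, emb in shard)
    return heapq.nlargest(hit, scored)


def retrieve(query, passages, hit, per_shard_doc_num):
    """Score passages shard by shard, then merge the per-shard candidates.

    `passages` is a list of (passage id, embedding) pairs.
    """
    candidates = []
    for start in range(0, len(passages), per_shard_doc_num):
        shard = passages[start:start + per_shard_doc_num]
        candidates.extend(search_shard(query, shard, hit))
    # Every global top-`hit` passage is also in its own shard's top-`hit`,
    # so merging the candidates gives the same result as one big search.
    return [pid for _, pid in heapq.nlargest(hit, candidates)]
```

Because the merge step only sees at most `hit` candidates per shard, a smaller shard size trades more passes for a lower peak memory footprint without changing the ranking.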

You can also retrieve the passages with only one GPU.

export CUDA_VISIBLE_DEVICES=0
python retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --per_gpu_doc_num 1800000
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv

The results should be:

#####################
MRR @10: 0.3038783713103188
QueriesRanked: 6980
#####################
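For readers unfamiliar with the metric, MRR@10 averages over queries the reciprocal rank of the first relevant passage within the top 10, counting zero when none appears. The sketch below is an independent illustration, not the code in ms_marco_eval.py; the function name and the choice to average over all queries that have relevance judgments are assumptions.

```python
def mrr_at_k(qrels, run, k=10):
    """Mean reciprocal rank at cutoff k.

    qrels: dict mapping query id -> set of relevant passage ids.
    run:   dict mapping query id -> list of passage ids in rank order.
    """
    total = 0.0
    for qid, relevant in qrels.items():
        ranked = run.get(qid, [])[:k]
        for rank, pid in enumerate(ranked, start=1):
            if pid in relevant:
                total += 1.0 / rank
                break  # only the first relevant passage counts
    return total / len(qrels)
```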

Train RepBERT

Next, download qidpidtriples.train.full.tsv.gz from MSMARCO-Passage-Ranking.

cd ./data/msmarco-passage
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.tsv.gz

Extract it and use the shuf command to generate a smaller file containing 10% of the triples.

shuf ./qidpidtriples.train.full.tsv -o ./qidpidtriples.train.small.tsv -n 26991900
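If shuf is not available, the extraction and sampling step can be approximated in Python. The sketch below reads the gzip file directly and keeps a uniform random sample of n lines using reservoir sampling. The function name sample_lines and the fixed seed are assumptions, and the sampled lines are held in memory, so at the full size used above shuf remains the more practical choice.

```python
import gzip
import random


def sample_lines(gz_path, out_path, n, seed=0):
    """Write a uniform random sample of up to n lines from a gzip text file."""
    rng = random.Random(seed)
    reservoir = []
    with gzip.open(gz_path, "rt", encoding="utf-8") as src:
        for i, line in enumerate(src):
            if i < n:
                reservoir.append(line)
            else:
                # Keep the new line with probability n / (i + 1).
                j = rng.randint(0, i)
                if j < n:
                    reservoir[j] = line
    rng.shuffle(reservoir)  # avoid keeping the original file order
    with open(out_path, "w", encoding="utf-8") as dst:
        dst.writelines(reservoir)
    return len(reservoir)
```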

Start training. Note that the evaluation result reported during training measures reranking performance, not first-stage retrieval.

python ./train.py --task train --evaluate_during_training
