
sprint's Introduction


SPRINT provides a unified repository to easily evaluate diverse state-of-the-art neural (BERT-based) sparse-retrieval models.

The SPRINT toolkit allows you to easily search with or evaluate any neural sparse retriever across any dataset in the BEIR benchmark (or your own dataset). The toolkit provides evaluation of seven diverse (neural) sparse retrieval models: SPLADEv2, BT-SPLADE-L, uniCOIL, TILDEv2, DeepImpact, DocT5query and SPARTA.

The SPRINT toolkit is built as a wrapper around Pyserini. It performs inference with a five-step sequential pipeline that unifies sparse-retrieval evaluation. The process is illustrated in the figure below:

To learn more about the SPRINT toolkit, please refer to our paper:

๐Ÿƒ Getting Started

SPRINT is backed by Pyserini, which relies on Java. To make installation easier, we recommend following the steps below via conda:

# Create a new conda environment
$ conda create -n sprint_env python=3.8
$ conda activate sprint_env

# Install JDK 11 via conda
$ conda install -c conda-forge openjdk=11

# Install SPRINT toolkit using PyPI
$ pip install sprint-toolkit

๐Ÿƒ Quickstart with SPRINT Toolkit

Quick start

For a quick start, we can go to the example for evaluating SPLADE (distilsplade_max) on the BeIR/SciFact dataset:

cd examples/inference/distilsplade_max/beir_scifact
bash all_in_one.sh

This will go over the whole pipeline and give the final evaluation results in beir_scifact-distilsplade_max-quantized/evaluation/metrics.json:

Results: distilsplade_max on BeIR/SciFact
   cat beir_scifact-distilsplade_max-quantized/evaluation/metrics.json 
   # {
   #     "nDCG": {
   #         "NDCG@1": 0.60333,
   #         "NDCG@3": 0.65969,
   #         "NDCG@5": 0.67204,
   #         "NDCG@10": 0.6925,
   #         "NDCG@100": 0.7202,
   #         "NDCG@1000": 0.72753
   #     },
   #     "MAP": {
   #         "MAP@1": 0.57217,
   #     ...
   # }

Or, if you prefer running Python directly, just run the code snippet below to evaluate castorini/unicoil-noexp-msmarco-passage on BeIR/SciFact:

from sprint.inference import aio


if __name__ == '__main__':  # aio.run can only be called within __main__
    aio.run(
        encoder_name='unicoil',
        ckpt_name='castorini/unicoil-noexp-msmarco-passage',
        data_name='beir/scifact',
        gpus=[0, 1],
        output_dir='beir_scifact-unicoil_noexp',
        do_quantization=True,
        quantization_method='range-nbits',  # So the doc term weights will be quantized by `(term_weights / 5) * (2 ** 8)`
        original_score_range=5,
        quantization_nbits=8,
        original_query_format='beir',
        topic_split='test'
    )
    # You would get "NDCG@10": 0.68563
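
As a concrete illustration of the range-nbits quantization used above, here is a minimal worked sketch based on the formula in the code comment (the toolkit's internal rounding may differ slightly):

# Minimal sketch of 'range-nbits' quantization, following the comment above:
# quantized = int((term_weight / original_score_range) * (2 ** quantization_nbits)).
def quantize_term_weight(term_weight: float, score_range: float = 5.0, nbits: int = 8) -> int:
    return int((term_weight / score_range) * (2 ** nbits))

print(quantize_term_weight(1.25))  # (1.25 / 5) * 256 = 64
print(quantize_term_weight(4.90))  # (4.90 / 5) * 256 = 250.88 -> 250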

Step by step

One can also run the above process in six separate steps under the step_by_step folder (a minimal Python sketch of the index, search, and evaluate steps follows the list):

  1. encode: Encode documents into term weights, using multiprocessing on multiple GPUs;
  2. quantize: Quantize the document term weights into integers (can be skipped);
  3. index: Index the term weights into a Lucene index (backed by Pyserini);
  4. reformat: Reformat the queries file (e.g. the ones from BeIR) into the Pyserini format;
  5. search: Retrieve the relevant documents (backed by Pyserini);
  6. evaluate: Evaluate the results against labeled data, e.g. the qrels used in BeIR (backed by BeIR).
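
For orientation, below is a minimal Python sketch of the index, search, and evaluate steps. It assumes the per-step modules are importable from sprint.inference and expose run() functions with arguments like those in the BM25 issue example further down this page (older versions used the package name sparse_retrieval.inference); treat the argument names as illustrative, not authoritative:

from sprint.inference import index, search, evaluate  # assumption: per-step modules live here


if __name__ == '__main__':
    # Index quantized term weights as a Lucene impact index (backed by Pyserini).
    index.run(
        collection='JsonCollection',
        input='beir_scifact-unicoil_noexp/quantized/collection',  # hypothetical path
        index='beir_scifact-unicoil_noexp/index',
        generator='DefaultLuceneDocumentGenerator',
        impact=True,
        pretokenized=True,
        threads=12,
    )
    # Search the index with the query encoder (backed by Pyserini).
    search.run(
        topics='beir_scifact-unicoil_noexp/queries/test.tsv',  # Pyserini-format queries from the reformat step
        encoder_name='unicoil',
        ckpt_name='castorini/unicoil-noexp-msmarco-passage',
        impact=True,
        index='beir_scifact-unicoil_noexp/index',
        output='beir_scifact-unicoil_noexp/trec-format/run.tsv',
        output_latency='beir_scifact-unicoil_noexp/trec-format/latency.tsv',
        hits=1000,
        batch_size=32,
        threads=12,
        output_format='trec',
        min_idf=-1,
    )
    # Evaluate the run file against the BEIR qrels (backed by BeIR).
    evaluate.run(
        result_path='beir_scifact-unicoil_noexp/trec-format/run.tsv',
        latency_path='beir_scifact-unicoil_noexp/trec-format/latency.tsv',
        index_path='beir_scifact-unicoil_noexp/index',
        format='trec',
        qrels_path='datasets/beir/scifact/qrels/test.tsv',  # hypothetical path
        output_dir='beir_scifact-unicoil_noexp/evaluation',
        bins=10,
        k_values=[1, 3, 5, 10, 100, 1000],
    )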

Currently it directly supports the seven methods listed above (with reproduction verified).

Currently it supports the following data formats (downloaded automatically):

  • BeIR

Other models and data (formats) will be added.
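
For reference, a BeIR dataset ships a corpus.jsonl with one JSON document per line (fields _id, title, text). Below is a minimal sketch of how such a line maps to the Pyserini collection format used for indexing, mirroring the conversion done in the BM25 issue example further down this page:

import json

# One BeIR corpus.jsonl line and its Pyserini-style counterpart (illustrative document).
beir_doc = {"_id": "doc0", "title": "Example title", "text": "Example passage text."}

pyserini_doc = {
    "id": beir_doc["_id"],
    "contents": beir_doc["text"],
    "title": beir_doc["title"],
}
print(json.dumps(pyserini_doc))
# {"id": "doc0", "contents": "Example passage text.", "title": "Example title"}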

Custom encoders

To add a custom encoder, one can refer to the example examples/inference/custom_encoder/beir_scifact, where distilsplade_max is evaluated on BeIR/SciFact with stopwords filtered out.

In detail, one just needs to define the custom encoder classes and write a new encoder-builder function:

from typing import Dict, List
from pyserini.encode import QueryEncoder, DocumentEncoder

class CustomQueryEncoder(QueryEncoder):

    def encode(self, text, **kwargs) -> Dict[str, float]:
        # Just an example:
        terms = text.split()
        term_weights = {term: 1 for term in terms}
        return term_weights  # Dict object, where keys/values are terms/term scores, resp.

class CustomDocumentEncoder(DocumentEncoder):

    def encode(self, texts, **kwargs) -> List[Dict[str, float]]:
        # Just an example:
        term_weights_batch = []
        for text in texts:
            terms = text.split()
            term_weights = {term: 1 for term in terms}
            term_weights_batch.append(term_weights)
        return term_weights_batch 

def custom_encoder_builder(ckpt_name, etype, device='cpu'):
    if etype == 'query':
        return CustomQueryEncoder(ckpt_name, device=device)
    elif etype == 'document':
        return CustomDocumentEncoder(ckpt_name, device=device)
    else:
        raise ValueError(f"Unknown encoder type: {etype}")

Then register custom_encoder_builder with sprint.inference.encoder_builders.register before usage:

from sprint.inference.encoder_builders import register

register('custom_encoder_builder', custom_encoder_builder)
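
After registration, the builder can be used by passing its registered name as the encoder name. A minimal sketch, assuming encoder_name accepts the registered builder name just like the built-in 'unicoil' builder in the example above (the checkpoint name shown is hypothetical and may be ignored by your custom classes):

from sprint.inference import aio


if __name__ == '__main__':
    aio.run(
        encoder_name='custom_encoder_builder',    # the name registered above
        ckpt_name='distilbert-base-uncased',      # hypothetical; your encoder may not need it
        data_name='beir/scifact',
        gpus=[0],
        output_dir='beir_scifact-custom_encoder',
        do_quantization=False,                    # assumption: quantization can be skipped
        original_query_format='beir',
        topic_split='test'
    )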

Training (Experimental)

Will be added.

Contacts

The main contributors of this repository are justram, kwang2049, and thakur-nandan.


sprint's Issues

Suspicious results of BM25

I got suspicious results for BM25 on HotpotQA (see the scores in the code comments at the bottom).

import json
import os
from sparse_retrieval.inference import index, search, evaluate
import tqdm


if __name__ == "__main__":  # aio.run can only be called within __main__
    data_dir = "../hotpotqa"
    collection_dir = os.path.join(data_dir, "collection")
    if not os.path.exists(collection_dir):
        os.makedirs(collection_dir, exist_ok=True)
        with open(os.path.join(data_dir, "corpus.jsonl")) as fin, open(
            os.path.join(collection_dir, "documents.jsonl"), "w"
        ) as fout:
            for line in tqdm.tqdm(fin, total=5233329):
                line_dict = json.loads(line)
                line_out = {
                    "id": line_dict["_id"],
                    "contents": line_dict["text"],
                    "title": line_dict["title"],
                }
                fout.write(json.dumps(line_out) + "\n")

    index_dir = "./index"
    if not os.path.exists(index_dir):
        index.run(
            collection="JsonCollection",
            input=collection_dir,
            index=index_dir,
            generator="DefaultLuceneDocumentGenerator",
            impact=False,
            pretokenized=False,
            threads=12,
        )

    output_path = "./trec-format/run.tsv"
    output_latency_path = "./trec-format/latency.tsv"
    if not os.path.exists(output_path):
        search.run(
            topics=os.path.join(data_dir, "queries-test.tsv"),
            encoder_name=None,
            impact=False,
            ckpt_name=None,
            index=index_dir,
            output=output_path,
            output_latency=output_latency_path,
            hits=1000 + 1,
            batch_size=1,
            threads=1,
            output_format="trec",
            min_idf=-1,
            bm25=True,
            fields=["contents=1.0", "title=1.0"]
        )
    
    evaluate.run(
        result_path=output_path,
        latency_path=output_latency_path,
        index_path=index_dir,
        format="trec",
        qrels_path=os.path.join(data_dir, "qrels", "test.tsv"),
        output_dir="./evaluation",
        bins=10,
        k_values=[1, 2, 3, 5, 10, 20, 100, 1000],
    )


# {
#     "nDCG": {
#         "NDCG@1": 0.03038,
#         "NDCG@2": 0.02255,
#         "NDCG@3": 0.02454,
#         "NDCG@5": 0.02719,
#         "NDCG@10": 0.0308,
#         "NDCG@20": 0.03473,
#         "NDCG@100": 0.04297,
#         "NDCG@1000": 0.05644
#     },
#     "MAP": {
#         "MAP@1": 0.01519,
#         "MAP@2": 0.01789,
#         "MAP@3": 0.01906,
#         "MAP@5": 0.0203,
#         "MAP@10": 0.02159,
#         "MAP@20": 0.02253,
#         "MAP@100": 0.02344,
#         "MAP@1000": 0.02382
#     },
#     "Recall": {
#         "Recall@1": 0.01519,
#         "Recall@2": 0.02026,
#         "Recall@3": 0.0235,
#         "Recall@5": 0.02876,
#         "Recall@10": 0.03788,
#         "Recall@20": 0.05064,
#         "Recall@100": 0.08839,
#         "Recall@1000": 0.18035
#     },
#     "Precision": {
#         "P@1": 0.03038,
#         "P@2": 0.02026,
#         "P@3": 0.01567,
#         "P@5": 0.01151,
#         "P@10": 0.00758,
#         "P@20": 0.00506,
#         "P@100": 0.00177,
#         "P@1000": 0.00036
#     },
#     "mrr": {
#         "MRR@1": 0.03038,
#         "MRR@2": 0.03518,
#         "MRR@3": 0.03716,
#         "MRR@5": 0.03943,
#         "MRR@10": 0.04171,
#         "MRR@20": 0.04333,
#         "MRR@100": 0.04481,
#         "MRR@1000": 0.04534
#     },
#     "latency": {
#         "latency_avg": 0.03811540496960071,
#         "query_word_length_avg": 17.444294395678597,
#         "binned": {
#             "word_length_bins": [
#                 7.0,
#                 12.0,
#                 17.0,
#                 22.0,
#                 27.0,
#                 32.0,
#                 37.0,
#                 42.0,
#                 47.0,
#                 52.0,
#                 57.0
#             ],
#             "freqs": [
#                 1092,
#                 2675,
#                 2064,
#                 962,
#                 391,
#                 153,
#                 48,
#                 16,
#                 2,
#                 2
#             ],
#             "latencies_avg": [
#                 0.029643922404560086,
#                 0.03363995291153404,
#                 0.03924605450701228,
#                 0.045261269053850875,
#                 0.05286584836919137,
#                 0.06515988533452056,
#                 0.06344743787000577,
#                 0.08745476719923317,
#                 0.06267806049436331,
#                 0.057787854224443436
#             ],
#             "latencies_std": [
#                 0.012647166341674091,
#                 0.01210004323494389,
#                 0.01437605419108805,
#                 0.01759624604937569,
#                 0.01885252533135144,
#                 0.02530032182733274,
#                 0.020812859140483156,
#                 0.04858220706849878,
#                 0.00041049253195524216,
#                 0.0
#             ]
#         },
#         "batch_size": 1.0,
#         "processor": " Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz"
#     },
#     "index_size": "428.73MB"
# }

ArguAna evaluation scores are incorrect

@kwang2049 like we talked about earlier in the meeting.

We are facing self-retrieval issues with the ArguAna dataset, which lead to low scores for all sparse techniques. We should fix this by removing all doc_ids identical to the query_ids, so that no self-retrieval happens.
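
Until that is fixed in the toolkit, one possible stopgap (hypothetical post-processing, not part of SPRINT) is to drop self-retrieved hits from the TREC-format run file before evaluation:

# Hypothetical post-processing: remove hits whose doc_id equals the query_id from a
# TREC-format run file (columns: qid Q0 docid rank score tag), then re-evaluate.
def drop_self_retrieval(run_in: str, run_out: str) -> None:
    with open(run_in) as fin, open(run_out, 'w') as fout:
        for line in fin:
            query_id, _q0, doc_id, *_rest = line.split()
            if doc_id != query_id:
                fout.write(line)

drop_self_retrieval('trec-format/run.tsv', 'trec-format/run.no_self.tsv')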

Change preprocessing script to not update the original BEIR qrels file

@kwang2049 as discussed earlier in the meeting.

python -m sparse_retrieval.inference.reformat_query \
    --original_format 'beir' \
    --data_dir datasets/beir/msmarco

We should change the preprocessing script above, as it renames the qrels file from test.tsv to test.bak.tsv. This causes issues: if I want to reuse the BEIR datasets for evaluation, I have to go back and rename the files manually.
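
Until the script is changed, a possible workaround (hypothetical paths, assuming the renaming behaviour described above) is to copy the backed-up qrels file back after reformatting:

import shutil

# Hypothetical workaround: restore the original BEIR qrels file that reformat_query
# renamed to test.bak.tsv, so the dataset can be reused for other evaluations.
qrels_dir = 'datasets/beir/msmarco/qrels'
shutil.copyfile(f'{qrels_dir}/test.bak.tsv', f'{qrels_dir}/test.tsv')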
