ukplab / gpl

Powerful unsupervised domain adaptation method for dense retrieval. Requires only an unlabeled corpus and yields massive improvement: "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval" https://arxiv.org/abs/2112.07577

License: Apache License 2.0

Python 99.22% Shell 0.78%
information-retrieval domain-adaptation nlp transformers vector-search bert

gpl's Introduction

Generative Pseudo Labeling (GPL)

GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on query generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs only the unlabeled target corpus and can achieve significant improvement over zero-shot models.

For more information, check out our publication:

For reproduction, please refer to this snapshot branch.

Installation

One can either install GPL via pip

pip install gpl

or via git clone

git clone https://github.com/UKPLab/gpl.git && cd gpl
pip install -e .

Also, please make sure that the installed PyTorch version matches your CUDA version.

Usage

GPL accepts data in the BeIR-format. For example, we can download the FiQA dataset hosted by BeIR:

wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip
unzip fiqa.zip
head -n 2 fiqa/corpus.jsonl  # Inspect the data format. GPL needs only this `corpus.jsonl` file as training input.

Then we can either run GPL training directly from the command line with `python -m`:

export dataset="fiqa"
python -m gpl.train \
    --path_to_generated_data "generated/$dataset" \
    --base_ckpt "distilbert-base-uncased" \
    --gpl_score_function "dot" \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --new_size -1 \
    --queries_per_passage -1 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --retriever_score_functions "cos_sim" "cos_sim" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --qgen_prefix "qgen" \
    --do_evaluation \
    # --use_amp   # Use this for efficient training if the machine supports AMP

# Run `python -m gpl.train --help` to see all available arguments
# To reproduce the experiments in the paper, set `base_ckpt` to "GPL/msmarco-distilbert-margin-mse" (https://huggingface.co/GPL/msmarco-distilbert-margin-mse)

or import GPL's training method in a Python script:

import gpl

dataset = 'fiqa'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt="distilbert-base-uncased",  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  
    # The starting checkpoint of the experiments in the paper
    gpl_score_function="dot",
    # Note that GPL uses MarginMSE loss, which works with dot-product
    batch_size_gpl=32,
    gpl_steps=140000,
    new_size=-1,
    # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set to 3 and |corpus| will be set to 250K / 3
    queries_per_passage=-1,
    # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set to 3 and |corpus| will be set to 250K / 3
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    retriever_score_functions=["cos_sim", "cos_sim"],
    # Note that these two retriever models work with cosine similarity
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
    do_evaluation=True,
    # use_amp=True   # One can use this flag for enabling the efficient float16 precision
)
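The automatic setting described in the `new_size` and `queries_per_passage` comments can be worked through numerically. The following is a sketch reconstructed from those comments, not GPL's own implementation; rounding up with `math.ceil` is an assumption, chosen because it reproduces the values logged in the CUDA out-of-memory issue below (a 4252-document corpus, with `new_size` set to 83334 and `queries_per_passage` set to 59). The helper name `auto_corpus_and_qpp` is hypothetical.

```python
import math
from typing import Tuple

# Budget of generated query-passage pairs targeted by the automatic setting (250K in the comments above).
TARGET_PAIRS = 250_000


def auto_corpus_and_qpp(corpus_size: int) -> Tuple[int, int]:
    """Hypothetical sketch of new_size=-1 and queries_per_passage=-1.

    Returns (number of passages used, queries per passage).
    """
    if corpus_size <= 0:
        raise ValueError("corpus_size must be positive")
    # new_size=-1: cap the corpus at ceil(250K / 3) passages, so large corpora end up with QPP = 3.
    new_size = math.ceil(TARGET_PAIRS / 3)
    # A corpus that is already smaller than new_size is kept at its full size.
    used_passages = min(corpus_size, new_size)
    # queries_per_passage=-1: spread the 250K budget over the passages that are kept.
    qpp = math.ceil(TARGET_PAIRS / used_passages)
    return used_passages, qpp


print(auto_corpus_and_qpp(4252))       # (4252, 59)
print(auto_corpus_and_qpp(1_000_000))  # (83334, 3)
```

The number of training tuples does not depend on this rule: with the defaults `gpl_steps=140000` and `batch_size_gpl=32`, the pseudo-labeled training file holds 140000 * 32 = 4,480,000 tuples.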

One can also refer to this toy example on Google Colab for a better understanding of how the code works.

How does GPL work?

The workflow of GPL is shown as follows:

  1. GPL first uses a seq2seq model (BeIR/query-gen-msmarco-t5-base-v1 by default) to generate queries_per_passage queries for each passage in the unlabeled corpus. These query-passage pairs are used as positive training examples.

    Result files (under path $path_to_generated_data): (1) ${qgen}-qrels/train.tsv, (2) ${qgen}-queries.jsonl and also (3) corpus.jsonl (copied from $evaluation_data/);

  2. Then, it runs negative mining on the target corpus, using the generated queries as input. The mined passages are used as negative training examples. Any dense retriever (SBERT or Hugging Face transformers checkpoints; msmarco-distilbert-base-v3 + msmarco-MiniLM-L-6-v3 by default) or BM25 can be passed to the retrievers argument as the negative miner.

    Result file (under path $path_to_generated_data): hard-negatives.jsonl;

  3. Finally, it pseudo-labels all query-passage pairs obtained so far (both positive and negative examples) with a cross-encoder (cross-encoder/ms-marco-MiniLM-L-6-v2 by default).

    Result file (under path $path_to_generated_data): gpl-training-data.tsv. It contains (gpl_steps * batch_size_gpl) tuples in total.
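Step 2 amounts to ranking the corpus for each generated query and keeping the highest-ranked passages that are not that query's positive. Below is a toy sketch with precomputed vectors and cosine similarity (the score function of the default retrievers). The vectors, `k`, and the helper names are hypothetical; GPL's own miner runs the actual retriever models and may differ in details such as how many candidates it keeps or how the results of several retrievers are combined.

```python
import math
from typing import Dict, List, Sequence, Set


def cos_sim(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either vector is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mine_hard_negatives(
    query_emb: Sequence[float],
    passage_embs: Dict[str, Sequence[float]],
    positive_ids: Set[str],
    k: int,
) -> List[str]:
    """Return the k most similar passage IDs that are not positives for this query."""
    ranked = sorted(passage_embs, key=lambda pid: cos_sim(query_emb, passage_embs[pid]), reverse=True)
    return [pid for pid in ranked if pid not in positive_ids][:k]


# Toy 2-d "embeddings"; in GPL these would come from the models passed via --retrievers.
passages = {"p1": [1.0, 0.0], "p2": [0.9, 0.1], "p3": [0.0, 1.0], "p4": [0.7, 0.7]}
print(mine_hard_negatives([1.0, 0.0], passages, positive_ids={"p1"}, k=2))  # ['p2', 'p4']
```

The positive passage is excluded explicitly because it would otherwise usually rank first for its own generated query.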

At this point, the actual training data is ready. See sample-data/generated/fiqa for a quick example of the data format. The last step applies the MarginMSE loss to teach the student retriever to mimic the margin scores CE(query, positive) - CE(query, negative) labeled by the teacher model (Cross-Encoder, CE). This MarginMSE step is also part of GPL and runs automatically. Note that MarginMSE works with dot-product, so the final models trained with GPL also work with dot-product.
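The MarginMSE objective can be written out on plain numbers. A minimal sketch, assuming dot-product scores for the student and given cross-encoder scores for the teacher; it mirrors the loss as described here rather than GPL's or sentence-transformers' batched tensor implementation, and the cross-encoder scores in the example are made up.

```python
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot-product score, the student's score function in GPL."""
    return sum(x * y for x, y in zip(a, b))


def margin_mse(student_pos: Sequence[float], student_neg: Sequence[float],
               teacher_pos: Sequence[float], teacher_neg: Sequence[float]) -> float:
    """Mean squared error between student margins and teacher (cross-encoder) margins over a batch."""
    n = len(student_pos)
    if n == 0 or not (n == len(student_neg) == len(teacher_pos) == len(teacher_neg)):
        raise ValueError("score lists must be non-empty and of equal length")
    errors = [((sp - sn) - (tp - tn)) ** 2
              for sp, sn, tp, tn in zip(student_pos, student_neg, teacher_pos, teacher_neg)]
    return sum(errors) / n


query, pos, neg = [1.0, 2.0], [2.0, 1.0], [1.0, 0.0]
s_pos, s_neg = dot(query, pos), dot(query, neg)  # 4.0 and 1.0
# Hypothetical cross-encoder scores: CE(query, positive) = 5.0, CE(query, negative) = 2.0.
print(margin_mse([s_pos], [s_neg], [5.0], [2.0]))  # 0.0: student margin 3.0 equals teacher margin 3.0
```

Only the difference between the two scores is supervised, so the student is free to use a different score scale than the cross-encoder as long as the margins agree.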

PS: The --retrievers are for negative mining. They can be any dense retrievers trained on the general domain (e.g. MS MARCO) and do not need to be strong for the target task/domain. Please refer to the paper for more details (cf. Table 7).

Customized data

One can also place custom data for any intermediate step under $path_to_generated_data, using the same file names. GPL will then skip those intermediate steps and use the provided data instead.
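Before a run, one can check which intermediate files are already present. A small hypothetical helper, using the file names listed in the workflow above with the default `qgen` prefix; it only reports what exists, while the decision to skip a step is made inside `gpl.train` and may involve additional checks.

```python
from pathlib import Path
from typing import Dict


def existing_intermediate_files(path_to_generated_data: str, qgen_prefix: str = "qgen") -> Dict[str, bool]:
    """Report which of the documented intermediate files already exist (hypothetical helper)."""
    root = Path(path_to_generated_data)
    expected = {
        "corpus": root / "corpus.jsonl",                                 # input corpus
        "generated qrels": root / f"{qgen_prefix}-qrels" / "train.tsv",  # step 1
        "generated queries": root / f"{qgen_prefix}-queries.jsonl",      # step 1
        "hard negatives": root / "hard-negatives.jsonl",                 # step 2
        "pseudo-labeled data": root / "gpl-training-data.tsv",           # step 3
    }
    return {name: path.is_file() for name, path in expected.items()}


for name, present in existing_intermediate_files("generated/fiqa").items():
    print(f"{name:20s} {'found' if present else 'missing'}")
```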

In a typical workflow, one has only an (English) unlabeled corpus and wants a model that performs well on it. To run GPL training in this setting, only these steps are needed:

  1. Prepare your corpus in the same format as the data sample;
  2. Put your corpus.jsonl in a folder, e.g. one named "generated", which GPL uses for data loading and data generation;
  3. Call gpl.train with the folder path as an input argument (other arguments work as usual):
python -m gpl.train \
    --path_to_generated_data "generated" \
    --output_dir "output" \
    --new_size -1 \
    --queries_per_passage -1
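For step 1, a corpus from another source can be written in the BeIR style with a few lines of Python. A minimal sketch, assuming one JSON object per line with `_id`, `title` and `text` fields (as in the FiQA sample); the helper names are hypothetical. IDs are written as strings as a precaution: IDs read back from TSV files such as the generated qrels are strings, so a corpus with integer `_id` values may fail those lookups.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, Tuple


def write_beir_corpus(docs: Iterable[Tuple[object, str, str]], folder: str) -> Path:
    """Write (doc_id, title, text) triples to <folder>/corpus.jsonl, one JSON object per line."""
    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "corpus.jsonl"
    with path.open("w", encoding="utf-8") as f:
        for doc_id, title, text in docs:
            # Cast IDs to str so they compare equal to IDs parsed from TSV files.
            record = {"_id": str(doc_id), "title": title, "text": text}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


def read_beir_corpus(path: Path) -> Dict[str, Dict[str, str]]:
    """Load corpus.jsonl back into {doc_id: {"title": ..., "text": ...}} to sanity-check it."""
    corpus: Dict[str, Dict[str, str]] = {}
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            if line.strip():
                doc = json.loads(line)
                corpus[str(doc["_id"])] = {"title": doc.get("title", ""), "text": doc["text"]}
    return corpus
```

For example, `write_beir_corpus([(3, "", "This is the domain text")], "generated")` writes the line `{"_id": "3", "title": "", "text": "This is the domain text"}` to generated/corpus.jsonl.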

Pre-trained checkpoints and generated data

Pre-trained checkpoints

We release the pre-trained GPL models on https://huggingface.co/GPL. There are currently five types of models:

  1. GPL/${dataset}-msmarco-distilbert-gpl: Model with training order of (1) MarginMSE on MSMARCO -> (2) GPL on ${dataset};
  2. GPL/${dataset}-tsdae-msmarco-distilbert-gpl: Model with training order of (1) TSDAE on ${dataset} -> (2) MarginMSE on MSMARCO -> (3) GPL on ${dataset};
  3. GPL/msmarco-distilbert-margin-mse: Model trained on MSMARCO with MarginMSE;
  4. GPL/${dataset}-tsdae-msmarco-distilbert-margin-mse: Model with training order of (1) TSDAE on ${dataset} -> (2) MarginMSE on MSMARCO;
  5. GPL/${dataset}-distilbert-tas-b-gpl-self_miner: Starting from the tas-b model, the models were trained with GPL on the target corpus ${dataset} with the base model itself as the negative miner (here noted as "self_miner").

Models 1 and 2 were trained on top of models 3 and 4, respectively. All GPL models were trained with the automatic setting of new_size and queries_per_passage (both set to -1), which keeps performance while being efficient. For more details, please refer to Section 4.1 of the paper.
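The released checkpoints can be tried out as ordinary dot-product retrievers. A minimal sketch, assuming the `sentence-transformers` package is installed and that the checkpoints load with its `SentenceTransformer` class; `search_with_gpl_model` and `rank_by_dot` are hypothetical helpers, and the dot-product scoring follows the note above that GPL-trained models work with dot-product.

```python
from typing import List, Sequence, Tuple


def rank_by_dot(query_emb: Sequence[float], passage_embs: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """Return (passage index, dot-product score) pairs, best first."""
    scores = [sum(q * p for q, p in zip(query_emb, emb)) for emb in passage_embs]
    return sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)


def search_with_gpl_model(model_name: str, query: str, passages: List[str]) -> List[Tuple[str, float]]:
    """Encode with a released checkpoint and rank passages by dot product."""
    # Imported here so rank_by_dot can be used without sentence-transformers installed.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name)
    query_emb = model.encode(query).tolist()
    passage_embs = model.encode(passages).tolist()
    return [(passages[i], score) for i, score in rank_by_dot(query_emb, passage_embs)]
```

A call such as `search_with_gpl_model("GPL/fiqa-msmarco-distilbert-gpl", query, passages)` fills the dataset into the first naming pattern above; check the Hugging Face page for the exact set of released datasets.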

Among these models, the GPL/${dataset}-distilbert-tas-b-gpl-self_miner models work best on the BeIR benchmark.

For reproducing the results with the same package versions used in the experiments, please refer to the conda environment file, environment.yml.

Generated data

We release the generated data used in the experiments of the GPL paper:

  1. The generated data for the main experiments on the 6 BeIR datasets: https://public.ukp.informatik.tu-darmstadt.de/kwang/gpl/generated-data/main/;
  2. The generated data for the experiments on the full 18 BeIR datasets: https://public.ukp.informatik.tu-darmstadt.de/kwang/gpl/generated-data/beir.

Please note that the 4 datasets bioasq, robust04, trec-news and signal1m are only available after registration with the original official providers. For these corpora, we release only the document IDs, in the file corpus.doc_ids.txt. For more details, please refer to the BeIR repository.

Citation

If you use the code for evaluation, feel free to cite our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval:

@article{wang2021gpl,
    title = "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval",
    author = "Kexin Wang and Nandan Thakur and Nils Reimers and Iryna Gurevych", 
    journal= "arXiv preprint arXiv:2112.07577",
    month = "4",
    year = "2021",
    url = "https://arxiv.org/abs/2112.07577",
}

Contact person and main contributor: Kexin Wang, [email protected]

https://www.ukp.tu-darmstadt.de/

https://www.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

gpl's People

Contributors

dpetrak, kwang2049

gpl's Issues

Multi-lingual GPL

Hi,

A general problem with multilingual models: they perform unequally across languages when the proportion of docs in language A greatly exceeds the proportion of docs in language B.

Wouldn't it be beneficial to translate all docs into all languages before fine-tuning the multilingual model with multilingual GPL?

Thanks!

Do we have to keep the intermediate results?

This is more a question than an issue.

I just ran a training with gpl_steps set to 50000. Five folders were created in the output folder (10000, 20000, 30000, 40000, 50000).

Does the pytorch_model.bin at the root level of the folder encompass all the knowledge obtained during training, so that I can use just that?

Do we have to keep the intermediate folders?

base checkpoint selection

I see in the code that two models (distilbert-base-uncased, msmarco-distilbert-margin-mse) are recommended to use as initial checkpoints. I tried to use other Sentence-Transformers models like all-mpnet-base-v2 but it didn't work. Is there a difference in the architecture of the models and the implementation out there? What models can be used here as initial checkpoints?

Error while running the training script

[2022-04-14 06:00:25] INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling
0%| | 0/140000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/gpl/gpl/toolkit/pl.py", line 63, in run
    batch = next(hard_negative_iterator)
  File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 569, in _next_data
    index = self._next_index() # may raise StopIteration
  File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in _next_index
    return next(self._sampler_iter) # may raise StopIteration
StopIteration

pytrec_eval dependency issue

I'm not able to complete the gpl install due to this known issue: cvangysel/pytrec_eval#32

I'm running Python 3.9 on Windows 10. Any suggestions?

Update: I tried Python 3.6 through 3.9, to no avail. I was able to take the package yml, make some light edits, and install the BEIR and pytrec-eval dependencies individually (using pytrec-eval-terrier, and a conda no-deps flag for the BEIR build).

Model not saved

After running the following script, the model does not seem to be saved.
Script run:
!python -m gpl.train \
    --path_to_generated_data "generated" \
    --output_dir "output" \
    --new_size -1 \
    --queries_per_passage -1

Final output:
Iteration: 100% 139995/140000 [2:39:11<00:00, 14.71it/s]
Iteration: 100% 139997/140000 [2:39:11<00:00, 14.73it/s]
Iteration: 100% 140000/140000 [2:39:12<00:00, 14.66it/s]
Epoch: 100% 1/1 [2:39:12<00:00, 9552.27s/it]

Whereas in the Colab example, the model does seem to get saved (link):

[2022-06-27 21:55:11] INFO [sentence_transformers.SentenceTransformer.save:352] Save model to output/fiqa
[2022-06-27 21:55:12] INFO [sentence_transformers.SentenceTransformer.save:352] Save model to output/fiqa/100

KeyError during pseudo labeling

Hi,

I am facing a KeyError during pseudo labeling. It looks like the selected pos_pid is not found in the corpus.

INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling
.....
File ~gpl/toolkit/dataset.py:78, in HardNegativeDataset._sample_tuple(self, query_dict)
     75 query_text = self.queries[query_dict['qid']]
     77 pos_pid = random.choice(pos_pids)
---> 78 pos_text = concat_title_and_body(pos_pid, self.corpus, self.sep)
     80 neg_pid = random.choice(list(neg_pids))
     81 neg_text = concat_title_and_body(neg_pid, self.corpus, self.sep)

File ~gpl/toolkit/dataset.py:12, in concat_title_and_body(did, corpus, sep)
     10 def concat_title_and_body(did, corpus, sep):
     11     document = []
---> 12     title = corpus[did]['title'].strip()
     13     body = corpus[did]['text'].strip()
     14     if len(title):

KeyError: '92974'

My corpus has the structure below. Does the order of the fields, or the fact that the _id values are numbers, matter?

{"text":"This is the domain text","_id":3,"title":"","metadata":{}}
{"text":"This is the domain text 2","_id":4,"title":"","metadata":{}}

Code to train:

gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    mnrl_output_dir="mnrl_output_dir",
    mnrl_evaluation_output="mnrl_evaluation_output",
    base_ckpt="distilbert-base-uncased",  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  
    # The starting checkpoint of the experiments in the paper
    gpl_score_function="dot",
    # Note that GPL uses MarginMSE loss, which works with dot-product
    batch_size_gpl=64,
    gpl_steps=140000,
    new_size=-1,
    # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set 3 and |corpus| will be set to 250K / 3
    queries_per_passage=-1,
    # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set 3 and |corpus| will be set to 250K / 3
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    retriever_score_functions=["cos_sim", "cos_sim"],
    # Note that these two retriever model work with cosine-similarity
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
    do_evaluation=True,
    # --use_amp   # One can use this flag for enabling the efficient float16 precision
)

Could you help me figure out what I am missing or doing wrong?

GPL with low performant CE

Does it make sense to train a model using GPL when the CE used for pseudo labeling performs poorly on the domain dataset (i.e. using the CE directly for IR tasks on the domain dataset gives poor results)? I would think the GPL-trained model would also perform poorly, since the CE performance represents the upper bound that GPL can achieve.

If my reasoning is correct, is there a way to deal with this shortcoming?

Guidance on gpl_steps, new_size and batch_size_gpl

Hello,

I am looking for some guidance on the following parameters of gpl.train():

  • gpl_steps - Do we need a value as large as 140000 for a corpus of size 1300?
  • new_size
  • batch_size_gpl - Would setting this to 64 or 128 speed up training?

How should the values of these parameters be derived from the dataset or corpus.jsonl?

Evaluation every 10k steps

Hello,

I can't figure out how to use evaluation during training: I'm not sure what the data format is or how to plot nDCG@k. I did manage to do it after training, by loading the saved models in a loop and predicting on my evaluation data.

My question is: does evaluating only every 10k steps make sense? How can I be sure there are no big variations in between? My training keeps improving after 100k steps and is not as smooth as in your paper.
Screenshot 2023-01-31 at 14 31 01

Any hints would be greatly appreciated.

What are the effects of overfitting for downstream tasks?

I was trying to adapt the sentence-transformers/multi-qa-mpnet-base-dot-v1 model to the financial domain using SEC data using GPL.

I trained the model with the following hyperparams:

{
"learning_rate": 0.00002,
"num_examples_eval": 1000, 
"num_examples_train" : 20000,
"num_epochs": 15
}

My loss curves were as follows:
Train Loss Curve
Validation Loss Curve

It seems the model is overfitting, but the trained model does not perform well even with early stopping: I trained one for 3 epochs, and the unadapted models perform better than the trained ones. I would appreciate some insight into why this happens; I don't really know where to ask this question. If there is a more suitable place for it, please let me know and I will take it there, especially because this is more of a theoretical question than something tied to this library.

I am relatively new to training models, so please let me know if I am making any obvious mistakes here (or if any other information is required).

issue while running the training script

Hi,

I have created a custom corpus.jsonl in the format described in the instructions.
I was able to install the gpl library successfully on a Mac.

I use the following piece of code:
import gpl

dataset = 'fiqa'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt="distilbert-base-uncased",
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',
    # The starting checkpoint of the experiments in the paper
    gpl_score_function="dot",
    # Note that GPL uses MarginMSE loss, which works with dot-product
    batch_size_gpl=32,
    gpl_steps=140000,
    new_size=-1,
    # Resize the corpus to new_size (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set 3 and |corpus| will be set to 250K / 3
    queries_per_passage=-1,
    # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set 3 and |corpus| will be set to 250K / 3
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    retriever_score_functions=["cos_sim", "cos_sim"],
    # Note that these two retriever model work with cosine-similarity
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
    do_evaluation=True,
    # --use_amp # One can use this flag for enabling the efficient float16 precision
)

I have changed the following paths:

path_to_generated_data=f"generated/{dataset}",
Here I set the path to my custom corpus.jsonl.

When I run this file, I get the following error:

train() missing 2 required positional arguments: "mnrl_output_dir" and "mnrl_evaluation_output"

My purpose here is to do domain adaption for questions in form of sentences for semantic search task.

Please let me know what would be the exact steps to train on custom data ?

Evaluation data format

Hi,

1/ How should the evaluation data format be as passed in the evaluation_data argument? Could you provide me some example of evaluation data and how it should be formatted?

2/ How does the evaluation work on these data? What are the tests passed and labels used?

Thanks!

Training on multi gpu

Is there a way to train a GPL model on multiple GPUs? If so, would that help with training on larger batches?

Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language")?

Hi.
Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language" in your example)? In your pipeline the first step is Query Generation: For a given text from our domain, we first use a T5 model that generates a possible query for the given text. E.g. when your text is “Python is a high-level general-purpose programming language”, the model might generate a query like “What is Python”. You can find various query generators on our doc2query-hub. Does that mean that texts which couldn't be converted into queries (e.g. "Investment consulting for legal entities and individuals.") cannot be used for training?

RuntimeError: CUDA out of memory

Hi,

When trying to generate intermediate results with the following command:

dataset = 'tiny'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
    gpl_score_function="dot",
    # Note that GPL uses MarginMSE loss, which works with dot-product
    batch_size_gpl=32,
    gpl_steps=140000,
    new_size=-1,
    # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), the |corpus| will be the full size. When set to -1, the |corpus| will be set automatically: If QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set 3 and |corpus| will be set to 250K / 3
    queries_per_passage=-1,
    # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), the QPP will be chosen automatically: If QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set 3 and |corpus| will be set to 250K / 3
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-large-v1",
    retrievers=["msmarco-distilbert-base-tas-b", "msmarco-MiniLM-L6-cos-v5"],
    retriever_score_functions=["dot", "cos_sim"],
    # Note that these two retriever model work with cosine-similarity
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
    do_evaluation=True,
    use_amp=True   # One can use this flag for enabling the efficient float16 precision
)

I got the following error:

2022-08-26 11:55:08 - Loading faiss with AVX2 support.
2022-08-26 11:55:08 - Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2022-08-26 11:55:08 - Loading faiss.
2022-08-26 11:55:08 - Successfully loaded faiss.
[2022-08-26 11:55:10] INFO [gpl.train.train:79] Corpus does not exist in generated/tiny. Now clone the one from the evaluation path ./tiny
[2022-08-26 11:55:10] INFO [gpl.train.train:84] Automatically set `new_size` to 83334
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 277639.61it/s]
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
[2022-08-26 11:55:10] WARNING [gpl.toolkit.resize.resize:19] `new_size` should be smaller than the corpus size
[2022-08-26 11:55:10] INFO [gpl.toolkit.resize.resize:41] Resized the corpus in ./tiny to generated/tiny with new size 83334
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 321974.74it/s]
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
[2022-08-26 11:55:10] INFO [gpl.train.train:99] Automatically set `queries_per_passage` to 59
[2022-08-26 11:55:10] INFO [gpl.train.train:125] No generated queries found. Now generating it
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 308459.11it/s]
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
[2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
[2022-08-26 11:55:20] INFO [beir.generation.models.auto_model.__init__:16] Use pytorch device: cuda
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:40] Starting to Generate 59 Questions Per Passage using top-p (nucleus) sampling...
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:41] Params: top_p = 0.95
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:42] Params: top_k = 25
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:43] Params: max_length = 64
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:44] Params: ques_per_passage = 59
[2022-08-26 11:55:21] INFO [beir.generation.generate.generate:45] Params: batch size = 32
pas:   0%|                                                                                                                                                                                          | 0/133 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/matthieu/Tinycoaching/GPL/v.0.1.0/gpl_query_generation.py", line 316, in <module>
    gpl.train(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/train.py", line 127, in train
    qgen(path_to_generated_data, path_to_generated_data, generator_name_or_path=generator, ques_per_passage=queries_per_passage, bsz=batch_size_generation, qgen_prefix=qgen_prefix)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/toolkit/qgen.py", line 23, in qgen
    generator.generate(corpus, output_dir=output_dir, ques_per_passage=ques_per_passage, prefix=prefix, batch_size=bsz)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/generate.py", line 54, in generate
    queries = self.model.generate(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/models/auto_model.py", line 28, in generate
    outs = self.model.generate(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1326, in generate
    return self.sample(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1944, in sample
    outputs = self(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1639, in forward
    decoder_outputs = self.decoder(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1035, in forward
    layer_outputs = layer_module(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 692, in forward
    cross_attention_outputs = self.layer[1](
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 606, in forward
    attention_output = self.EncDecAttention(
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 509, in forward
    scores = torch.matmul(
RuntimeError: CUDA out of memory. Tried to allocate 584.00 MiB (GPU 0; 23.70 GiB total capacity; 20.69 GiB already allocated; 587.94 MiB free; 20.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My corpus consists of short paragraphs of 3-4 lines, and I used the use_amp option. How can I deal with this?
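One common workaround for an out-of-memory error like this is to retry generation with a smaller batch. The sketch below is a hypothetical wrapper, not part of GPL or BeIR: `run_with_batch_fallback` and the `generate_fn` callable it receives are assumptions, and you would wrap your own call to the query generator (for example one that passes `bsz=batch_size`) inside `generate_fn`. It relies only on PyTorch reporting CUDA OOM as a `RuntimeError` whose message contains "out of memory", as in the traceback above. A retry restarts the whole wrapped call, so it is most useful around per-shard or per-batch calls. Setting `PYTORCH_CUDA_ALLOC_CONF`, as the error message suggests, is a separate option.

```python
def _free_cuda_cache():
    """Release cached GPU blocks if PyTorch with CUDA is available."""
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def run_with_batch_fallback(generate_fn, batch_size, min_batch_size=1):
    """Call generate_fn(batch_size), halving batch_size after each CUDA OOM.

    Returns (result, batch_size_that_worked). Non-OOM errors, and an OOM
    that still happens at min_batch_size, are re-raised unchanged.
    """
    while True:
        try:
            return generate_fn(batch_size), batch_size
        except RuntimeError as err:
            # PyTorch reports CUDA OOM as a RuntimeError (newer versions raise the
            # torch.cuda.OutOfMemoryError subclass) with "out of memory" in the message.
            if "out of memory" not in str(err) or batch_size <= min_batch_size:
                raise
            _free_cuda_cache()
            batch_size = max(min_batch_size, batch_size // 2)


if __name__ == "__main__":
    def fake_generate(bsz):
        # Stand-in for a real generation call: pretend only batches of 8 or fewer fit.
        if bsz > 8:
            raise RuntimeError("CUDA out of memory. Tried to allocate 584.00 MiB")
        return f"generated with bsz={bsz}"

    print(run_with_batch_fallback(fake_generate, 32))  # ('generated with bsz=8', 8)
```

Halving keeps the number of retries logarithmic in the starting batch size, and re-raising at `min_batch_size` avoids looping forever when even a single passage does not fit.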

TSDAE + GPL with French data

Hello, thanks for the amazing work.
I am trying to do domain adaptation using TSDAE + GPL with an unlabeled French dataset. For TSDAE there are a few good base models, such as CamemBERT. Once I have pretrained with TSDAE, I intend to use GPL like so:

gpl.train(
    path_to_generated_data="generated",
    base_ckpt="MY TSDAE MODEL",  
    gpl_score_function="dot",
    batch_size_gpl=32,
    gpl_steps=-1,
    new_size=-1,
    queries_per_passage=1,
    output_dir="output",
    generator="doc2query/msmarco-french-mt5-base-v1",
    retrievers=["antoinelouis/biencoder-msmarco-distilbert-cos-v5-mmarcoFR", "antoinelouis/biencoder-msmarco-MiniLM-L12-cos-v5-mmarcoFR"],
    retriever_score_functions=["cos_sim", "cos_sim"],
    cross_encoder="cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",
    qgen_prefix="qgen",
    do_evaluation=False,
    # use_amp=True   # One can use this flag for enabling the efficient float16 precision
)

The generator, retrievers, and cross-encoder are all French models. The code seems to work, but I'm not sure my choice of models is right, since there is no information about using GPL for other languages. Does this configuration seem okay to you?

Also, can you please confirm my understanding of (1) TSDAE on ${target} -> (2) MarginMSE on MSMARCO -> (3) GPL on ${target}?
The base model (CamemBERT in my case) is pretrained via TSDAE (step 1). When I plug it into GPL, step (2) is done automatically (training on the MS MARCO dataset, which is apparently provided in the GPL package), and then the actual GPL training is done on my unlabeled corpus (the target, which is the same corpus used in step 1).

And if this is true, how do I train on a French version of MS MARCO? This whole "MarginMSE on MSMARCO" step confuses me: why do we need it if the retrievers are already trained on such datasets?

Thanks.

Problem with TensorFlow when installing GPL in a Python environment

Hi,

When I create a conda environment with python==3.8.8 and try to install GPL in it using pip install gpl, the installation loops endlessly, collecting successively older versions of tensorflow:

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Collecting gpl
  Using cached gpl-0.0.9-py3-none-any.whl (24 kB)
Collecting beir
  Using cached beir-0.2.3.tar.gz (52 kB)
Collecting easy-elasticsearch>=0.0.7
  Using cached easy_elasticsearch-0.0.7-py3-none-any.whl (12 kB)
Collecting elasticsearch==7.12.1
  Using cached elasticsearch-7.12.1-py2.py3-none-any.whl (339 kB)
Collecting requests
  Downloading requests-2.27.1-py2.py3-none-any.whl (63 kB)
     |████████████████████████████████| 63 kB 1.5 MB/s 
Collecting tqdm
  Using cached tqdm-4.62.3-py2.py3-none-any.whl (76 kB)
Requirement already satisfied: certifi in ./anaconda3/envs/gpl_fresh/lib/python3.8/site-packages (from elasticsearch==7.12.1->easy-elasticsearch>=0.0.7->gpl) (2021.10.8)
Collecting urllib3<2,>=1.21.1
  Downloading urllib3-1.26.8-py2.py3-none-any.whl (138 kB)
     |████████████████████████████████| 138 kB 12.7 MB/s 
Collecting sentence-transformers
  Using cached sentence_transformers-2.1.0-py3-none-any.whl
Collecting pytrec_eval
  Using cached pytrec_eval-0.5.tar.gz (15 kB)
Collecting faiss_cpu
  Using cached faiss_cpu-1.7.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.8.0-cp38-cp38-manylinux2010_x86_64.whl (497.6 MB)
Collecting tensorflow-text
  Using cached tensorflow_text-2.7.3-cp38-cp38-manylinux2010_x86_64.whl (4.9 MB)
Collecting tensorflow-hub
  Using cached tensorflow_hub-0.12.0-py2.py3-none-any.whl (108 kB)
Requirement already satisfied: setuptools in ./anaconda3/envs/gpl_fresh/lib/python3.8/site-packages (from tensorflow>=2.2.0->beir->gpl) (58.0.4)
Collecting grpcio<2.0,>=1.24.3
  Downloading grpcio-1.43.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.1 MB)
     |████████████████████████████████| 4.1 MB 2.2 MB/s 
Collecting typing-extensions>=3.6.6
  Downloading typing_extensions-4.0.1-py3-none-any.whl (22 kB)
Collecting keras-preprocessing>=1.1.1
  Using cached Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)
Collecting wrapt>=1.11.0
  Downloading wrapt-1.13.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (84 kB)
     |████████████████████████████████| 84 kB 9.6 MB/s 
Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Using cached tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
Collecting tensorboard<2.9,>=2.8
  Using cached tensorboard-2.8.0-py3-none-any.whl (5.8 MB)
Collecting google-pasta>=0.1.1
  Using cached google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting six>=1.12.0
  Using cached six-1.16.0-py2.py3-none-any.whl (11 kB)
Collecting absl-py>=0.4.0
  Using cached absl_py-1.0.0-py3-none-any.whl (126 kB)
Collecting opt-einsum>=2.3.2
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting protobuf>=3.9.2
  Downloading protobuf-3.19.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
     |████████████████████████████████| 1.1 MB 25.7 MB/s 
Collecting libclang>=9.0.1
  Using cached libclang-13.0.0-py2.py3-none-manylinux1_x86_64.whl (14.5 MB)
Collecting keras<2.9,>=2.8.0rc0
  Using cached keras-2.8.0-py2.py3-none-any.whl (1.4 MB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1
  Using cached tensorflow_io_gcs_filesystem-0.23.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.1 MB)
Collecting termcolor>=1.1.0
  Using cached termcolor-1.1.0-py3-none-any.whl
Collecting h5py>=2.9.0
  Using cached h5py-3.6.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (4.5 MB)
Requirement already satisfied: numpy>=1.20 in ./anaconda3/envs/gpl_fresh/lib/python3.8/site-packages (from tensorflow>=2.2.0->beir->gpl) (1.21.2)
Collecting astunparse>=1.6.0
  Using cached astunparse-1.6.3-py2.py3-none-any.whl (12 kB)
Collecting gast>=0.2.1
  Using cached gast-0.5.3-py3-none-any.whl (19 kB)
Collecting flatbuffers>=1.12
  Using cached flatbuffers-2.0-py2.py3-none-any.whl (26 kB)
Requirement already satisfied: wheel<1.0,>=0.23.0 in ./anaconda3/envs/gpl_fresh/lib/python3.8/site-packages (from astunparse>=1.6.0->tensorflow>=2.2.0->beir->gpl) (0.37.1)
Collecting google-auth-oauthlib<0.5,>=0.4.1
  Downloading google_auth_oauthlib-0.4.6-py2.py3-none-any.whl (18 kB)
Collecting google-auth<3,>=1.6.3
  Downloading google_auth-2.6.0-py2.py3-none-any.whl (156 kB)
     |████████████████████████████████| 156 kB 17.3 MB/s 
Collecting tensorboard-plugin-wit>=1.6.0
  Downloading tensorboard_plugin_wit-1.8.1-py3-none-any.whl (781 kB)
     |████████████████████████████████| 781 kB 27.3 MB/s 
Collecting werkzeug>=0.11.15
  Using cached Werkzeug-2.0.2-py3-none-any.whl (288 kB)
Collecting markdown>=2.6.8
  Downloading Markdown-3.3.6-py3-none-any.whl (97 kB)
     |████████████████████████████████| 97 kB 3.8 MB/s 
Collecting tensorboard-data-server<0.7.0,>=0.6.0
  Using cached tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl (4.9 MB)
Collecting rsa<5,>=3.1.4
  Downloading rsa-4.8-py3-none-any.whl (39 kB)
Collecting pyasn1-modules>=0.2.1
  Using cached pyasn1_modules-0.2.8-py2.py3-none-any.whl (155 kB)
Collecting cachetools<6.0,>=2.0.0
  Downloading cachetools-5.0.0-py3-none-any.whl (9.1 kB)
Collecting requests-oauthlib>=0.7.0
  Downloading requests_oauthlib-1.3.1-py2.py3-none-any.whl (23 kB)
Collecting importlib-metadata>=4.4
  Downloading importlib_metadata-4.10.1-py3-none-any.whl (17 kB)
Collecting zipp>=0.5
  Downloading zipp-3.7.0-py3-none-any.whl (5.3 kB)
Collecting pyasn1<0.5.0,>=0.4.6
  Using cached pyasn1-0.4.8-py2.py3-none-any.whl (77 kB)
Collecting charset-normalizer~=2.0.0
  Downloading charset_normalizer-2.0.11-py3-none-any.whl (39 kB)
Collecting idna<4,>=2.5
  Using cached idna-3.3-py3-none-any.whl (61 kB)
Collecting oauthlib>=3.0.0
  Downloading oauthlib-3.2.0-py3-none-any.whl (151 kB)
     |████████████████████████████████| 151 kB 26.8 MB/s 
Collecting scipy
  Using cached scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (39.3 MB)
Collecting scikit-learn
  Downloading scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
     |████████████████████████████████| 26.7 MB 2.9 MB/s 
Collecting torchvision
  Using cached torchvision-0.11.3-cp38-cp38-manylinux1_x86_64.whl (23.2 MB)
Collecting nltk
  Downloading nltk-3.6.7-py3-none-any.whl (1.5 MB)
     |████████████████████████████████| 1.5 MB 7.0 MB/s 
Collecting tokenizers>=0.10.3
  Downloading tokenizers-0.11.4-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
     |████████████████████████████████| 6.8 MB 2.5 MB/s 
Collecting huggingface-hub
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
     |████████████████████████████████| 67 kB 2.9 MB/s 
Collecting torch>=1.6.0
  Downloading torch-1.10.2-cp38-cp38-manylinux1_x86_64.whl (881.9 MB)
     |████████████████████████████████| 881.9 MB 5.9 kB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
     |████████████████████████████████| 1.2 MB 5.4 MB/s 
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
     |████████████████████████████████| 3.5 MB 2.2 MB/s 
Collecting filelock
  Downloading filelock-3.4.2-py3-none-any.whl (9.9 kB)
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
     |████████████████████████████████| 895 kB 9.6 MB/s 
Requirement already satisfied: pyyaml>=5.1 in ./.local/lib/python3.8/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers->beir->gpl) (5.4.1)
Collecting packaging>=20.0
  Downloading packaging-21.3-py3-none-any.whl (40 kB)
     |████████████████████████████████| 40 kB 2.4 MB/s 
Collecting regex!=2019.12.17
  Downloading regex-2022.1.18-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (764 kB)
     |████████████████████████████████| 764 kB 26.7 MB/s 
Collecting pyparsing!=3.0.5,>=2.0.2
  Downloading pyparsing-3.0.7-py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 4.1 MB/s 
Collecting joblib
  Using cached joblib-1.1.0-py2.py3-none-any.whl (306 kB)
Collecting click
  Using cached click-8.0.3-py3-none-any.whl (97 kB)
Collecting threadpoolctl>=2.0.0
  Downloading threadpoolctl-3.1.0-py3-none-any.whl (14 kB)
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.7.1-cp38-cp38-manylinux2010_x86_64.whl (495.1 MB)
Collecting gast>=0.2.1
  Using cached gast-0.4.0-py3-none-any.whl (9.8 kB)
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.7.0-cp38-cp38-manylinux2010_x86_64.whl (489.6 MB)
INFO: pip is looking at multiple versions of tensorflow-text to determine which version is compatible with other requirements. This could take a while.
Collecting tensorflow-text
  Using cached tensorflow_text-2.7.0-cp38-cp38-manylinux2010_x86_64.whl (4.9 MB)
  Using cached tensorflow_text-2.6.0-cp38-cp38-manylinux1_x86_64.whl (4.4 MB)
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.6.3-cp38-cp38-manylinux2010_x86_64.whl (463.9 MB)
Collecting six>=1.12.0
  Using cached six-1.15.0-py2.py3-none-any.whl (10 kB)
Collecting h5py>=2.9.0
  Downloading h5py-3.1.0-cp38-cp38-manylinux1_x86_64.whl (4.4 MB)
     |████████████████████████████████| 4.4 MB 2.4 MB/s 
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.6.2-cp38-cp38-manylinux2010_x86_64.whl (458.4 MB)
  Using cached tensorflow-2.6.1-cp38-cp38-manylinux2010_x86_64.whl (458.4 MB)
  Using cached tensorflow-2.6.0-cp38-cp38-manylinux2010_x86_64.whl (458.4 MB)
Collecting tensorflow-estimator~=2.6
  Using cached tensorflow_estimator-2.8.0-py2.py3-none-any.whl (462 kB)
Collecting tensorflow-text
  Using cached tensorflow_text-2.5.0-cp38-cp38-manylinux1_x86_64.whl (4.3 MB)
Collecting tensorflow>=2.2.0
  Using cached tensorflow-2.5.3-cp38-cp38-manylinux2010_x86_64.whl (460.4 MB)
  Downloading tensorflow-2.5.2-cp38-cp38-manylinux2010_x86_64.whl (454.5 MB)
     |████████████████████████████████| 454.5 MB 24 kB/s 
  Downloading tensorflow-2.5.1-cp38-cp38-manylinux2010_x86_64.whl (454.5 MB)
     |▍                               | 6.0 MB 266 kB/s eta 0:28:03^C
ERROR: Operation cancelled by user

Is there a way to fix this endless TensorFlow installation, and is it possible to install the GPU versions of PyTorch and TensorFlow?

GPL for sentence embedding tasks?

In the provided examples, GPL is used for semantic search: given a query, relevant results should be retrieved. Is it also the recommended approach for getting meaningful embeddings / bi-encoders, or is it better to use TSDAE?

Recommended GPU memory? CUDA out of memory during query generation.

Hi,

Thank you for the amazing library.

I am trying GPL on another dataset but run into problems during query generation. I am using a Google Colab V100 with 16 GB of GPU RAM and the following config. CUDA runs out of memory after about 10% of the iterations.

How much GPU memory did you need for your experiments? Do I need to split the corpus into 10 parts to run it?

import gpl

gpl.toolkit.qgen(
    data_path = "xxxxxx",
    output_dir = "xxxxxxxx",
    generator_name_or_path="doc2query/msmarco-french-mt5-base-v1",
    ques_per_passage=1,
    bsz=1,
    qgen_prefix="qgen",
)
RuntimeError: CUDA out of memory during query generation (queries_per_passage: 1, batch_size_generation: 1). Please try smaller `queries_per_passage` and/or `batch_size_generation`.
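If you do decide to split the corpus, the sketch below shows one way to do it, assuming the BeIR `corpus.jsonl` layout (one JSON object per line with `_id`, `title`, `text`). `shard_corpus` is a hypothetical helper: each shard gets its own folder containing a `corpus.jsonl`, so each folder could be passed as `data_path` to a separate `gpl.toolkit.qgen` call with its own `output_dir`. Splitting reduces the work per call, not the memory per batch, so it only helps if memory builds up over a run; whether that is the cause here is not established. If the generated query ids restart in each run, they may repeat across shards, so check them before merging the outputs.

```python
import json
import os


def shard_corpus(corpus_path, out_root, num_shards):
    """Split a BeIR corpus.jsonl into num_shards folders, each with its own corpus.jsonl.

    Documents are assigned round-robin so shard sizes differ by at most one.
    Returns the list of shard directories.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be at least 1")
    shard_dirs = [os.path.join(out_root, f"shard_{i}") for i in range(num_shards)]
    for d in shard_dirs:
        os.makedirs(d, exist_ok=True)
    handles = [open(os.path.join(d, "corpus.jsonl"), "w", encoding="utf-8") for d in shard_dirs]
    try:
        with open(corpus_path, encoding="utf-8") as f:
            n = 0
            for line in f:
                if not line.strip():
                    continue  # skip blank lines
                doc = json.loads(line)  # fails loudly on a malformed line
                # ensure_ascii=False keeps accented (e.g. French) text readable
                handles[n % num_shards].write(json.dumps(doc, ensure_ascii=False) + "\n")
                n += 1
    finally:
        for h in handles:
            h.close()
    return shard_dirs
```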

Thank you for your help!

Evaluation data and generation code not released?

Hi,

First of all, thank you for sharing this incredible work! It is truly amazing that you've shared your code, some model weights, and some generated data.

For my end-of-studies project, I'd like to adapt this work using the LoRA technique. I realized that the evaluation data (test.tsv) is not shared, and the code that creates the test.tsv file is not available in the repo. Is there a reason for this?
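For reference, BeIR-format evaluation folders keep relevance judgments in `qrels/test.tsv`: a tab-separated file with a `query-id`, `corpus-id`, `score` header and one row per judged query-document pair. The sketch below writes such a file from a nested dict. `write_beir_qrels` is a hypothetical helper, not code from the GPL repo, and it does not recreate the authors' test split; the judgments themselves still have to come from annotation or from an existing BeIR dataset.

```python
import csv
import os


def write_beir_qrels(qrels, out_dir, split="test"):
    """Write {query_id: {doc_id: score}} to <out_dir>/qrels/<split>.tsv in BeIR layout."""
    qrels_dir = os.path.join(out_dir, "qrels")
    os.makedirs(qrels_dir, exist_ok=True)
    path = os.path.join(qrels_dir, f"{split}.tsv")
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(["query-id", "corpus-id", "score"])  # header row
        for qid, docs in qrels.items():
            for did, score in docs.items():
                writer.writerow([qid, did, int(score)])  # graded relevance as an integer
    return path
```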

GPU speedup

I reckon this is more of a generic question for TSDAE + GPL (or any transformer), but can you use the GPU by simply doing something like gpl.to(device)?
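`gpl` is a Python package, not a `torch.nn.Module`, so there is no `.to(device)` method to call on it. The Azure log further down prints `Use pytorch device: cpu`, which suggests the underlying libraries pick the device themselves, presumably from `torch.cuda.is_available()`. A quick first check, assuming only that PyTorch is installed, is whether your PyTorch build can see a GPU at all; `describe_torch_device` is a hypothetical helper for that.

```python
def describe_torch_device():
    """Report whether PyTorch is installed with CUDA support and can see a GPU."""
    try:
        import torch
    except ImportError:
        return {"installed": False, "cuda_build": None, "gpu_visible": False, "device": "cpu"}
    gpu_visible = torch.cuda.is_available()
    return {
        "installed": True,
        # torch.version.cuda is None for CPU-only builds of PyTorch
        "cuda_build": torch.version.cuda,
        "gpu_visible": gpu_visible,
        "device": "cuda" if gpu_visible else "cpu",
    }


if __name__ == "__main__":
    print(describe_torch_device())
```

If `cuda_build` is `None`, installing a CUDA-enabled PyTorch build that matches your driver is usually the fix, rather than any GPL-side setting.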

Loss function

Is the minus sign "-" in the MarginMSE loss function, Equation (1) of the GPL paper, a typo?

There should be no minus sign, because the model should minimize MSE(delta_teacher, delta_student), not maximize it.
I checked the released GPL code, and its loss function has no minus sign.
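A plain-Python sketch of MarginMSE as described here, the mean squared difference between the student's and the teacher's score margins, shows why the loss should be minimized: it is never negative and is zero only when the margins agree, so negating it would push the student away from the teacher. This is an illustrative re-implementation, not the GPL or sentence-transformers code, and the scores are made-up numbers.

```python
def margin_mse(student_pos, student_neg, teacher_pos, teacher_neg):
    """Mean over (query, positive, negative) triples of (delta_student - delta_teacher) ** 2,
    where delta = score(query, positive) - score(query, negative)."""
    total = 0.0
    for sp, sn, tp, tn in zip(student_pos, student_neg, teacher_pos, teacher_neg):
        delta_student = sp - sn
        delta_teacher = tp - tn
        total += (delta_student - delta_teacher) ** 2
    return total / len(student_pos)


# Margins: student 2 vs teacher 4, then student 0 vs teacher 1 -> (4 + 1) / 2
print(margin_mse([5.0, 1.0], [3.0, 1.0], [6.0, 2.0], [2.0, 1.0]))  # 2.5
```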

Support for Azure?

I tried to run the toy example on Azure, and I believe it made it all the way through training on the generated data. My logs cut off abruptly, so I'm not sure what the full error is, but I wonder whether this is the culprit:

WARNING [root._load_auto_model:789] No sentence-transformers model found with name /root/.cache/torch/sentence_transformers/distilbert-base-uncased. Creating a new one with MEAN pooling.

Azure ML can only write to an outputs folder; could that be the issue? I'm guessing this warning comes from the BeIR data loader, though I couldn't find the code that emits it.

Training code:

import sys, os, joblib
import gpl

# Save the result to the outputs folder
os.makedirs("outputs", exist_ok=True)

dataset = 'fiqa'

gpl.train(
    path_to_generated_data = "generated/" + dataset,
    base_ckpt = "distilbert-base-uncased",
    gpl_score_function = "dot", 
    batch_size_gpl = 4, 
    gpl_steps = 100, 
    new_size = 10, 
    queries_per_passage = 1,
    output_dir = "outputs/" + dataset,
    evaluation_data = "./" + dataset, 
    evaluation_output = "evaluation/" + dataset,
    generator = "BeIR/query-gen-msmarco-t5-base-v1",
    retrievers = ["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"], 
    retriever_score_functions = ["cos_sim", "cos_sim"], 
    cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2", 
    mnrl_output_dir = None,
    mnrl_evaluation_output = None,
    qgen_prefix = "qgen",
)
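If the write restriction is the problem, one option is to route every directory the script writes to under `outputs/`. `gpl_output_paths` is a hypothetical helper that only builds the three path arguments used in the call above, keeping `output_dir` as it was; whether this fixes the truncated run is untested. Separately, the `No sentence-transformers model found ... Creating a new one with MEAN pooling` warning is what sentence-transformers logs when it loads a plain Hugging Face checkpoint such as `distilbert-base-uncased`, so on its own it is probably not the error.

```python
import os


def gpl_output_paths(dataset, root="outputs"):
    """Hypothetical helper: keep every directory GPL writes to under one root folder."""
    return {
        "path_to_generated_data": os.path.join(root, "generated", dataset),
        "output_dir": os.path.join(root, dataset),
        "evaluation_output": os.path.join(root, "evaluation", dataset),
    }


if __name__ == "__main__":
    paths = gpl_output_paths("fiqa")
    print(paths)
    # These would replace the three path arguments in the gpl.train call, e.g.
    # gpl.train(**paths, base_ckpt="distilbert-base-uncased", ...)
```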

Logs:


/azureml-envs/azureml_ec637423e82cc698715575ac22b521b8/lib/python3.6/site-packages/paramiko/transport.py:33: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
  from cryptography.hazmat.backends import default_backend
2022-06-29 19:52:08.489294: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /azureml-envs/azureml_ec637423e82cc698715575ac22b521b8/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
2022-06-29 19:52:08.489374: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2022-06-29 19:52:11 - Loading faiss with AVX2 support.
2022-06-29 19:52:11 - Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'",)
2022-06-29 19:52:11 - Loading faiss.
2022-06-29 19:52:11 - Successfully loaded faiss.
[2022-06-29 19:52:12] INFO [gpl.train.train:125] No generated queries found. Now generating it
[2022-06-29 19:52:12] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...

100%|██████████| 10/10 [00:00<00:00, 41486.69it/s]
[2022-06-29 19:52:12] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 10 Documents.
[2022-06-29 19:52:12] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': "I'm not saying I don't like the idea of on-the-job training too, but you can't expect the company to do that. Training workers is not their job - they're building software. Perhaps educational systems in the U.S. (or their students) should worry a little about getting marketable skills in exchange for their massive investment in education, rather than getting out with thousands in student debt and then complaining that they aren't qualified to do anything.", 'title': ''}

Downloading: 100%|██████████| 1.81k/1.81k [00:00<00:00, 1.58MB/s]

Downloading: 100%|██████████| 1.35k/1.35k [00:00<00:00, 1.12MB/s]

Downloading: 100%|██████████| 773k/773k [00:00<00:00, 11.1MB/s]

Downloading: 100%|██████████| 1.74k/1.74k [00:00<00:00, 1.59MB/s]

Downloading: 100%|██████████| 850M/850M [00:21<00:00, 41.5MB/s]
[2022-06-29 19:52:46] INFO [beir.generation.models.auto_model.__init__:16] Use pytorch device: cpu
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:40] Starting to Generate 1 Questions Per Passage using top-p (nucleus) sampling...
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:41] Params: top_p = 0.95
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:42] Params: top_k = 25
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:43] Params: max_length = 64
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:44] Params: ques_per_passage = 1
[2022-06-29 19:52:46] INFO [beir.generation.generate.generate:45] Params: batch size = 32

pas: 100%|██████████| 1/1 [00:18<00:00, 18.03s/it]
[2022-06-29 19:53:04] INFO [beir.generation.generate.generate:82] Saving 10 Generated Queries...
[2022-06-29 19:53:04] INFO [beir.generation.generate.save:23] Saving Generated Queries to generated/fiqa/qgen-queries.jsonl
[2022-06-29 19:53:04] INFO [beir.generation.generate.save:26] Saving Generated Qrels to generated/fiqa/qgen-qrels/train.tsv
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:67] Loading Corpus...

100%|██████████| 10/10 [00:00<00:00, 11963.22it/s]
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:69] Loaded 10 TRAIN Documents.
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:70] Doc Example: {'text': "I'm not saying I don't like the idea of on-the-job training too, but you can't expect the company to do that. Training workers is not their job - they're building software. Perhaps educational systems in the U.S. (or their students) should worry a little about getting marketable skills in exchange for their massive investment in education, rather than getting out with thousands in student debt and then complaining that they aren't qualified to do anything.", 'title': ''}
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:73] Loading Queries...
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:79] Loaded 10 TRAIN Queries.
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:80] Query Example: can you train yourself
[2022-06-29 19:53:04] INFO [gpl.train.train:136] No hard-negative data found. Now mining it
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:67] Loading Corpus...

100%|██████████| 10/10 [00:00<00:00, 76959.71it/s]
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:69] Loaded 10 TRAIN Documents.
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:70] Doc Example: {'text': "I'm not saying I don't like the idea of on-the-job training too, but you can't expect the company to do that. Training workers is not their job - they're building software. Perhaps educational systems in the U.S. (or their students) should worry a little about getting marketable skills in exchange for their massive investment in education, rather than getting out with thousands in student debt and then complaining that they aren't qualified to do anything.", 'title': ''}
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:73] Loading Queries...
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:79] Loaded 10 TRAIN Queries.
[2022-06-29 19:53:04] INFO [beir.datasets.data_loader.load:80] Query Example: can you train yourself
[2022-06-29 19:53:04] WARNING [gpl.toolkit.mine.__init__:42] `negatives_per_query` > corpus size. Please use a smaller `negatives_per_query`
[2022-06-29 19:53:04] INFO [gpl.toolkit.mine._mine_sbert:49] Mining with msmarco-distilbert-base-v3
[2022-06-29 19:53:04] INFO [sentence_transformers.SentenceTransformer.__init__:60] Load pretrained SentenceTransformer: msmarco-distilbert-base-v3

Downloading: 100%|██████████| 690/690 [00:00<00:00, 305kB/s]
Downloading: 100%|██████████| 190/190 [00:00<00:00, 176kB/s]
Downloading: 100%|██████████| 3.71k/3.71k [00:00<00:00, 2.23MB/s]
Downloading: 100%|██████████| 545/545 [00:00<00:00, 352kB/s]
Downloading: 100%|██████████| 122/122 [00:00<00:00, 96.9kB/s]
Downloading: 100%|██████████| 229/229 [00:00<00:00, 136kB/s]
Downloading: 100%|██████████| 265M/265M [00:06<00:00, 44.1MB/s]
Downloading: 100%|██████████| 53.0/53.0 [00:00<00:00, 42.0kB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 83.8kB/s]
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 2.36MB/s]
Downloading: 100%|██████████| 499/499 [00:00<00:00, 442kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 1.41MB/s]
[2022-06-29 19:53:16] INFO [sentence_transformers.SentenceTransformer.__init__:97] Use pytorch device: cpu

Batches: 100%|██████████| 1/1 [00:04<00:00,  4.63s/it]
100%|██████████| 1/1 [00:00<00:00,  4.04it/s]
[2022-06-29 19:53:21] INFO [gpl.toolkit.mine._mine_sbert:49] Mining with msmarco-MiniLM-L-6-v3
[2022-06-29 19:53:21] INFO [sentence_transformers.SentenceTransformer.__init__:60] Load pretrained SentenceTransformer: msmarco-MiniLM-L-6-v3

Downloading: 100%|██████████| 736/736 [00:00<00:00, 457kB/s]
Downloading: 100%|██████████| 190/190 [00:00<00:00, 164kB/s]
Downloading: 100%|██████████| 3.68k/3.68k [00:00<00:00, 2.15MB/s]
Downloading: 100%|██████████| 627/627 [00:00<00:00, 430kB/s]
Downloading: 100%|██████████| 122/122 [00:00<00:00, 80.3kB/s]
Downloading: 100%|██████████| 229/229 [00:00<00:00, 182kB/s]
Downloading: 100%|██████████| 90.9M/90.9M [00:02<00:00, 43.2MB/s]
Downloading: 100%|██████████| 53.0/53.0 [00:00<00:00, 46.7kB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 85.4kB/s]
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 2.34MB/s]
Downloading: 100%|██████████| 430/430 [00:00<00:00, 305kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 1.45MB/s]
[2022-06-29 19:53:28] INFO [sentence_transformers.SentenceTransformer.__init__:97] Use pytorch device: cpu

Batches: 100%|██████████| 1/1 [00:01<00:00,  1.56s/it]
100%|██████████| 1/1 [00:00<00:00, 12.29it/s]
[2022-06-29 19:53:30] INFO [gpl.toolkit.mine.run:114] Combining all the data

100%|██████████| 10/10 [00:00<00:00, 177724.75it/s]
[2022-06-29 19:53:30] INFO [gpl.toolkit.mine.run:126] Saving data to generated/fiqa/hard-negatives.jsonl
[2022-06-29 19:53:30] INFO [gpl.toolkit.mine.run:130] Done
[2022-06-29 19:53:30] INFO [gpl.train.train:147] No GPL-training data found. Now generating it via pseudo labeling

Downloading: 100%|██████████| 794/794 [00:00<00:00, 500kB/s]
Downloading: 100%|██████████| 86.7M/86.7M [00:02<00:00, 41.1MB/s]
Downloading: 100%|██████████| 316/316 [00:00<00:00, 270kB/s]
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 1.89MB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 71.6kB/s]
[2022-06-29 19:53:37] INFO [sentence_transformers.cross_encoder.CrossEncoder.__init__:55] Use pytorch device: cpu
[2022-06-29 19:53:39] INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling

100%|██████████| 100/100 [02:01<00:00,  1.22s/it]
[2022-06-29 19:55:41] INFO [gpl.toolkit.pl.run:80] Done pseudo labeling and saving data
[2022-06-29 19:55:41] INFO [gpl.toolkit.pl.run:84] Saved GPL-training data to generated/fiqa/gpl-training-data.tsv
[2022-06-29 19:55:41] INFO [gpl.train.train:168] Now doing training on the generated data with the MarginMSE loss
[2022-06-29 19:55:41] INFO [sentence_transformers.SentenceTransformer.__init__:60] Load pretrained SentenceTransformer: distilbert-base-uncased

Downloading: 100%|██████████| 391/391 [00:00<00:00, 327kB/s]
Downloading: 100%|██████████| 11.4k/11.4k [00:00<00:00, 8.79MB/s]
Downloading: 100%|██████████| 8.56k/8.56k [00:00<00:00, 6.58MB/s]
Downloading: 100%|██████████| 483/483 [00:00<00:00, 433kB/s]
Downloading: 100%|██████████| 268M/268M [00:06<00:00, 44.2MB/s]
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 3.14MB/s]
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 22.8kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 1.41MB/s]
[2022-06-29 19:55:50] WARNING [root._load_auto_model:789] No sentence-transformers model found with name /root/.cache/torch/sentence_transformers/distilbert-base-uncased. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at /root/.cache/torch/sentence_transformers/distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[2022-06-29 19:55:51] INFO [sentence_transformers.SentenceTransformer.__init__:97] Use pytorch device: cpu

Batches: 100%|██████████| 1/1 [00:00<00:00, 21.75it/s]
[2022-06-29 19:55:51] INFO [gpl.toolkit.sbert.load_sbert:44] Set max_seq_length=350
[2022-06-29 19:55:51] INFO [gpl.train.train:173] Load GPL training data from generated/fiqa/gpl-training-data.tsv
[2022-06-29 19:55:51] INFO [gpl.toolkit.loss.__init__:22] Set GPL score function to dot

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/100 [00:00<?, ?it/s]
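The log walks through GPL's stages: query generation, hard-negative mining with `msmarco-distilbert-base-v3` and `msmarco-MiniLM-L-6-v3`, pseudo labeling with a cross-encoder, and finally training "with the MarginMSE loss" using the `dot` score function. As a rough illustration of that last step (a plain-Python sketch of the idea, not the code in `gpl.toolkit.loss`), the cross-encoder's score margin between a positive and a mined negative passage serves as the label, and the dense retriever is trained so that its own dot-product margin matches it:

```python
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Dot-product similarity, the score function selected by `--gpl_score_function dot`."""
    return sum(a * b for a, b in zip(u, v))


def margin_mse_loss(
    queries: List[Sequence[float]],
    positives: List[Sequence[float]],
    negatives: List[Sequence[float]],
    ce_margins: List[float],
) -> float:
    """Mean squared error between the student's margin and the cross-encoder margin.

    ce_margins[i] is the pseudo label CE(q_i, pos_i) - CE(q_i, neg_i)
    produced by the cross-encoder during pseudo labeling.
    """
    if not queries:
        raise ValueError("empty batch")
    total = 0.0
    for q, pos, neg, label in zip(queries, positives, negatives, ce_margins):
        student_margin = dot(q, pos) - dot(q, neg)  # dense retriever's own margin
        total += (student_margin - label) ** 2
    return total / len(queries)
```

In actual training this is computed over batched PyTorch tensors with gradients flowing into the encoder; the per-example loop only keeps the arithmetic visible.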

Version conflict due to `easy-elasticsearch` and latest `beir`

There is a version conflict in GPL's dependencies. GPL requires `easy-elasticsearch>=0.0.7` (0.0.7 is the latest release of easy-elasticsearch), and easy-elasticsearch pins `elasticsearch==7.12.1`. GPL also depends on beir without any version constraint, and the latest beir pins `elasticsearch==7.9.1`. A dependency solver (such as Poetry) therefore installs `beir==0.0.4` when installing gpl, because beir 0.0.4 does not depend on elasticsearch at all. However, GPL is not compatible with `beir==0.0.4`, so a solver-based install ends up with an incompatible combination of GPL and BEIR.

Running `pip install gpl` does work, presumably because `easy-elasticsearch==0.0.7` also runs with `elasticsearch==7.9.1`. pip warns that `easy-elasticsearch 0.0.7 requires elasticsearch==7.12.1, but you have elasticsearch 7.9.1 which is incompatible`, but training still works correctly.

The easiest solution is to relax the dependencies of easy-elasticsearch or beir so that they allow any elasticsearch version between 7.9.1 and 7.12.1. I don't think gpl's setup.py can fix this on its own, because its two dependencies each pin a different version of elasticsearch. It could, however, set a minimum beir version, which would stop a package manager with a dependency solver from falling back to an incompatible beir.
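To check which of these situations a given environment is in, the standard-library `importlib.metadata` can report installed versions. A minimal sketch; `installed_versions` and `diagnose` are hypothetical helpers, and the two checks only encode what this issue reports (GPL failing with `beir==0.0.4`, and the elasticsearch pin mismatch that pip warns about but that reportedly does not break training):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Optional

PACKAGES = ("gpl", "beir", "easy-elasticsearch", "elasticsearch")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def diagnose(versions: Dict[str, Optional[str]]) -> List[str]:
    """Return human-readable notes for the two cases described in this issue."""
    notes: List[str] = []
    if versions.get("beir") == "0.0.4":
        notes.append("beir 0.0.4 is installed; GPL is reported not to work with it.")
    es = versions.get("elasticsearch")
    if versions.get("easy-elasticsearch") and es and es != "7.12.1":
        notes.append(
            f"elasticsearch {es} differs from the 7.12.1 pin of easy-elasticsearch; "
            "pip warns about this, but training reportedly still works."
        )
    return notes


if __name__ == "__main__":
    for note in diagnose(installed_versions()) or ["No known conflict detected."]:
        print(note)
```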

TSDAE + GPL and TAS-B + GPL

@kwang2049 Hi, thanks for your amazing work!

I wonder whether the TSDAE + GPL setting mentioned in the paper refers to fine-tuning the distilbert-base PLM in this order: (1) TSDAE on {dataset}, then (2) GPL on {dataset}?

Thx.

TSDAE to GPL... Error on start

I'm trying to take my trained TSDAE model and apply GPL on top of it, but I keep getting errors.

! export dataset="hs_resume_tsdae_gpl_mini" 
! python -m gpl.train \
    --path_to_generated_data "generated/$dataset" \
    --base_ckpt "/Users/cfeld/Desktop/dev/trajectory/finetuning/gpl/outputs/tsdae/MiniLM-L6-H384-uncased-model" \
    --gpl_score_function "dot" \
    --batch_size_gpl 34 \
    --gpl_steps 100 \
    --queries_per_passage 1 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --retriever_score_functions "cos_sim" "cos_sim" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --use_train_qrels

However, I'm getting this error:

2022-09-12 17:37:44 - Loading faiss.
2022-09-12 17:37:44 - Successfully loaded faiss.
/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py:127: RuntimeWarning: 'gpl.train' found in sys.modules after import of package 'gpl', but prior to execution of 'gpl.train'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
[2022-09-12 17:37:44] INFO [gpl.train.train:79] Corpus does not exist in generated/. Now clone the one from the evaluation path ./
[2022-09-12 17:37:44] WARNING [gpl.train.train:106] Found `qgen_prefix` is not None. By setting `use_train_qrels == True`, the `qgen_prefix` will not be used
[2022-09-12 17:37:44] INFO [gpl.train.train:113] Loading qrels and queries from labeled data under the path of `evaluation_data`
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 250, in <module>
    train(**vars(args))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 114, in train
    assert 'qrels' in os.listdir(evaluation_data) and 'queries.jsonl' in os.listdir(evaluation_data)
AssertionError

Perhaps my folder structure isn't quite right? I've tried all kinds of combinations. Folder:
corpus.jsonl
evaluation/
  corpus.jsonl
  hs_resume_tsdae_gpl_mini/
    corpus.jsonl
generated/
  corpus.jsonl
  hs_resume_tsdae_gpl_mini/
    corpus.jsonl
hs_resume_tsdae_gpl_mini/
  corpus.jsonl
output/
  hs_resume_tsdae_gpl_mini/
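The traceback ends in an `os.listdir` check at `gpl/train.py` line 114, which requires `qrels` and `queries.jsonl` inside `evaluation_data` when `--use_train_qrels` is set. The minimal sketch below reproduces that check so a folder can be validated before launching training. `check_beir_folder` is a hypothetical helper, not part of GPL, and it only mirrors the assertion shown in the log. Note also that the log prints `generated/` and `./` rather than `generated/hs_resume_tsdae_gpl_mini`, which suggests `$dataset` expanded to an empty string: in a Jupyter notebook each `!` line runs in its own subshell, so the `export` on the first line likely does not carry over to the `python -m gpl.train` command.

```python
import os
from typing import List


def check_beir_folder(evaluation_data: str, use_train_qrels: bool) -> List[str]:
    """Return the problems found in a BeIR-format data folder (empty list if none).

    Hypothetical helper: it only mirrors the file checks seen in the traceback.
    """
    if not os.path.isdir(evaluation_data):
        return [f"{evaluation_data!r} is not a directory"]
    entries = os.listdir(evaluation_data)
    problems = []
    if "corpus.jsonl" not in entries:
        problems.append("missing corpus.jsonl")
    if use_train_qrels:
        # Same condition as the failing assertion in gpl/train.py (line 114 in the log)
        if "qrels" not in entries:
            problems.append("missing qrels/ directory")
        if "queries.jsonl" not in entries:
            problems.append("missing queries.jsonl")
    return problems
```

For the folder listed above, `check_beir_folder("./hs_resume_tsdae_gpl_mini", use_train_qrels=True)` would report both the missing `qrels/` directory and the missing `queries.jsonl`. Two things worth trying, under these assumptions: set the variable with IPython's `%env dataset=hs_resume_tsdae_gpl_mini` so that later `!` subshells inherit it, and, if no labeled queries exist for this corpus, drop `--use_train_qrels`, since the log says that flag loads qrels and queries from labeled data.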

How to create a dataset for GPL training from a set of domain-specific Word documents or PDFs

First of all, thanks a lot for this unique paper. After reading it I wanted to try the approach, but I am a little confused about which initial data to create for GPL training. Currently all my domain content is in Word files or PDFs, and I don't have any labeled data. Since your paper says the method works with unlabeled data, could you please explain how to create the initial input data for it? Also, could you share any pretrained models that are available for experimenting with the method?
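Per the README, GPL only needs a BeIR-format `corpus.jsonl` as training input, where each line is a JSON object with `_id`, `title` and `text`. The minimal sketch below writes such a file from documents whose plain text has already been extracted from the Word or PDF files by a separate tool (not shown). The function names, the `<doc_id>-<chunk index>` ID scheme, and the 200-word passage length are illustrative assumptions, not GPL requirements.

```python
import json
from typing import Iterable, Iterator, Tuple


def chunk_words(text: str, max_words: int = 200) -> Iterator[str]:
    """Split text into consecutive passages of at most max_words words."""
    words = text.split()
    for start in range(0, len(words), max_words):
        yield " ".join(words[start:start + max_words])


def write_beir_corpus(docs: Iterable[Tuple[str, str, str]], path: str,
                      max_words: int = 200) -> int:
    """Write (doc_id, title, text) documents as BeIR-style corpus.jsonl lines.

    Each passage gets its own "_id" of the form "<doc_id>-<chunk index>"
    (a hypothetical scheme). Returns the number of passages written.
    """
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for doc_id, title, text in docs:
            for i, passage in enumerate(chunk_words(text, max_words)):
                record = {"_id": f"{doc_id}-{i}", "title": title, "text": passage}
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                count += 1
    return count
```

Splitting long documents into short passages keeps each entry closer to the passage lengths typical of retrieval corpora such as the FiQA example in the README. The resulting file would then go in the folder passed as `--path_to_generated_data`, which, judging by the log in the previous issue, is where GPL first looks for `corpus.jsonl`.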
