morningmoni / gar
Code and resources for papers "Generation-Augmented Retrieval for Open-Domain Question Answering" and "Reader-Guided Passage Reranking for Open-Domain Question Answering", ACL 2021
Thanks for your wonderful work. It brings a new angle to the IR task.
I am confused about the relationship between the GAR generator and the generative reader whose predicted answers RIDER uses for reranking.
As I understand it, the GAR input format is:
[start_token] question [end_token]
and the target is:
[start_token] answer/sentence/title [end_token]
while the reader's input format is:
[start_token] question [separate_token] passage [end_token]
What is the relationship between the GAR generator and the generator which provides the prediction for reranking?
Are these two independent models?
In addition, I could not find any call to rider.py in the published code. Could you provide some instructions on how to use it?
It seems like dense retrieval may achieve better results, but in GAR+ you only fuse the results of sparse and dense retrieval. Why not fine-tune DPR using the augmented queries? Or, if you have carried out such experiments, what was the performance?
I noticed that there were two checkpoints after training finished. Are they the same?
I trained three models on custom datasets, but every run produced checkpoints with the same names: ① checkpointepoch=1.ckpt and ② checkpointlast.ckpt.
So why is it epoch=1, and not 50 or 100? ("--num_train_epochs", default=8)
ImportError: cannot import name 'reweight_labels' from 'utils_gen'
I can't find the function reweight_labels in utils_gen.py, is there anything missing?
Hello!
Thanks for publishing your code. Do you plan to release the checkpoints? The command-line arguments you used to train them would also be very helpful.
Thanks.
Hi, I ran into a problem when fusing the results from the three sources, following the instructions in another issue: "Say you have 3 lists of retrieved docs [a1, a2, ...], [b1, b2, ...], [c1, c2, ...]. The combined list would be [a1, b1, c1, a2, b2, c2, ...]". Specifically, I treat lists a, b, and c as the answer, sentence, and title lists, respectively. I have checked my code many times, but I still get the slightly different results below. Is this reasonable?
Top5 accuracy: 0.5853
Top10 accuracy: 0.6604
Top20 accuracy: 0.7269
Top50 accuracy: 0.7967
Top100 accuracy: 0.8382
I used the provided augmented queries with BM25 on NQ, and for each individual source I get the same results as reported in another issue.
Here is my fusion code; I would be grateful if you could point out any mistakes. Alternatively, could you share your source code? It would be really helpful!
```python
import json


def fusion(n_doc, tit=True, sen=True, ans_path=None, tit_path=None, sen_path=None, output_path=None):
    if tit:
        with open(tit_path, "r") as f:
            data_tit = json.load(f)
    if sen:
        with open(sen_path, "r") as f:
            data_sen = json.load(f)
    with open(ans_path, "r") as f:
        data_ans = json.load(f)
    fusion_result = []
    for i, data in enumerate(data_ans):
        fusion_result.append(data)  # I do this to keep the same format as the evaluation file
        fusion_psgs = []
        # Interleave answer[j], title[j], sentence[j], checking the length after every append
        for j, ctx in enumerate(data["ctxs"]):
            fusion_psgs.append(ctx)
            if tit and len(fusion_psgs) < n_doc:
                fusion_psgs.append(data_tit[i]["ctxs"][j])
            if sen and len(fusion_psgs) < n_doc:
                fusion_psgs.append(data_sen[i]["ctxs"][j])
            if len(fusion_psgs) == n_doc:
                break
        assert len(fusion_psgs) == n_doc, f"question {i}: fusion passages length {len(fusion_psgs)} != {n_doc}"
        fusion_result[i]["ctxs"] = fusion_psgs
    with open(output_path, "w") as f:
        json.dump(fusion_result, f)
    return fusion_result
```
Thank you !!!
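The quoted recipe does not say what to do when the same passage appears in more than one list. One assumption worth testing (not confirmed by the authors) is that duplicates are skipped, so every slot in the top-k holds a distinct passage. Because the deduplicated list is the plain interleaving with repeats removed, its first k entries always contain every distinct passage from the plain list's first k, so deduplication cannot lower top-k accuracy and may raise it. A minimal sketch of round-robin fusion with deduplication; `interleave_fuse` is a hypothetical name, and passages are assumed to be identified by their "id" field, as in DPR-style retrieval files:

```python
from typing import Dict, List, Sequence


def interleave_fuse(ranked_lists: Sequence[List[Dict]], n_doc: int, key: str = "id") -> List[Dict]:
    """Round-robin merge [a1, b1, c1, a2, b2, c2, ...], skipping passages already taken.

    Assumption: each passage dict carries a unique identifier under `key`.
    """
    fused: List[Dict] = []
    seen = set()
    depth = max((len(lst) for lst in ranked_lists), default=0)
    for j in range(depth):
        for lst in ranked_lists:
            if j >= len(lst):
                continue  # this source has run out of passages
            passage = lst[j]
            if passage[key] in seen:
                continue  # already contributed by an earlier source or rank
            seen.add(passage[key])
            fused.append(passage)
            if len(fused) == n_doc:
                return fused
    return fused  # shorter than n_doc only if the sources overlap heavily
```

For example, with a = [1, 2, 3], b = [2, 4, 5], c = [1, 6, 7] (ids) and n_doc=5, this returns ids [1, 2, 4, 6, 3], whereas plain interleaving would give [1, 2, 1, 2, 4], with only three distinct passages.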
You include the DPR GitHub code without making it a git submodule. Have you changed it from the upstream repo, or is it the same apart from the train_reader.py file?
What does the training data for the answer generator look like, and where is it?
I can't find any information about the default data_dir="./cnn-dailymail/cnn_dm/" of class SummarizationDataset in utils_gen.py, line 156.
How should I download and prepare the training data for the answer generator?
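If the answer generator uses the same layout as the other GAR data files, that is, a `{split}.source` file with one question per line and a line-aligned `{split}.target` file (inferred from the `test.target` files and the `.source.processed` path that appear elsewhere on this page, not documented here), custom data could be prepared roughly as below. `write_split` is a hypothetical helper, and how multiple gold answers should be turned into one target line is left open:

```python
import os
from typing import Sequence


def write_split(data_dir: str, split: str, questions: Sequence[str], targets: Sequence[str]) -> None:
    """Write line-aligned <split>.source / <split>.target files, one example per line.

    Assumption: this mirrors the layout GAR's SummarizationDataset expects.
    """
    if len(questions) != len(targets):
        raise ValueError("questions and targets must have the same length")
    os.makedirs(data_dir, exist_ok=True)
    src_path = os.path.join(data_dir, f"{split}.source")
    tgt_path = os.path.join(data_dir, f"{split}.target")
    with open(src_path, "w", encoding="utf-8") as src, open(tgt_path, "w", encoding="utf-8") as tgt:
        for question, target in zip(questions, targets):
            # Collapse all whitespace (including newlines) so line i of both files stays aligned
            src.write(" ".join(question.split()) + "\n")
            tgt.write(" ".join(target.split()) + "\n")
```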
I was going through the provided GAR data files for NQ. It seems like nq-title is missing its test.target files. It would be great if this could be updated to include them.
Hello,
Just a note that fp16 is on by default. On my setup at least, I get NaNs when it is enabled.
Also, the argument has default=True together with action="store_true", which is odd: with that combination the flag can never be turned off from the command line, which makes me think it was supposed to be default=False.
```python
parser.add_argument(
    "--fp16",
    action='store_true',
    default=True,
    help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
```
https://github.com/morningmoni/GAR/blob/master/gar/conf.py#L102
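A standalone sketch (not the repository's conf.py) of why this pattern cannot be switched off: store_true only ever sets the value to True, so with default=True both invocations yield True. Changing the default to False turns it into an opt-in switch. `build_parser` is a hypothetical helper used only for the comparison:

```python
import argparse


def build_parser(opt_in: bool) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    if opt_in:
        # Off unless --fp16 is passed explicitly
        parser.add_argument("--fp16", action="store_true", default=False,
                            help="Whether to use 16-bit (mixed) precision instead of 32-bit")
    else:
        # The pattern quoted above: store_true can only set True, and the default is already True
        parser.add_argument("--fp16", action="store_true", default=True)
    return parser


print(build_parser(False).parse_args([]).fp16)          # True
print(build_parser(False).parse_args(["--fp16"]).fp16)  # True (no way to disable it)
print(build_parser(True).parse_args([]).fp16)           # False
print(build_parser(True).parse_args(["--fp16"]).fp16)   # True
```

If fp16 should stay on by default, Python 3.9+ offers `argparse.BooleanOptionalAction`, which also registers a matching `--no-fp16` switch.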
Hello, sorry to bother you. I ran into an issue when I tried to train the GAR model:
```
Traceback (most recent call last):
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 245, in val_dataloader
    return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size, num_workers=4)
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 225, in get_dataloader
    dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
  File "../gar/utils_gen.py", line 177, in __init__
    self.source = pickle.load(open(os.path.join(data_dir, type_path + f".source.processed{suffix}"), 'rb'))
ModuleNotFoundError: No module named 'transformers.tokenization_utils_base'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 308, in <module>
    main(args)
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 285, in main
    trainer = generic_train(model, args, logger, resume_cp_file=cp_file, )
  File "../gar/lightning_base.py", line 220, in generic_train
    trainer.fit(model)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1046, in fit
    results = self.accelerator_backend.spawn_ddp_children(model)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/accelerators/ddp_backend.py", line 123, in spawn_ddp_children
    results = self.ddp_train(local_rank, mp_queue=None, model=model, is_master=True)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/accelerators/ddp_backend.py", line 224, in ddp_train
    results = self.trainer.run_pretrain_routine(model)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1224, in run_pretrain_routine
    self._run_sanity_check(ref_model, model)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1249, in _run_sanity_check
    self.reset_val_dataloader(ref_model)
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/data_loading.py", line 337, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/data_loading.py", line 266, in _reset_eval_dataloader
    dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
  File "/home/zhangxy/anaconda3/envs/torch15DPR/lib/python3.6/site-packages/pytorch_lightning/trainer/data_loading.py", line 360, in request_dataloader
    dataloader = dataloader_fx()
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 248, in val_dataloader
    return self.get_dataloader("train", batch_size=self.hparams.eval_batch_size, num_workers=4)
  File "/home/zhangxy/QA/GAR-master/gar/train_generator.py", line 225, in get_dataloader
    dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
  File "../gar/utils_gen.py", line 177, in __init__
    self.source = pickle.load(open(os.path.join(data_dir, type_path + f".source.processed{suffix}"), 'rb'))
ModuleNotFoundError: No module named 'transformers.tokenization_utils_base'
```
I installed transformers==2.11.0 and tokenizers==0.7.0.
I ran the project with the following command:
```
GEN_TARGET='answer' python train_generator.py --remark generator_train_nq_A --train_batch_size 128 --eval_batch_size 256 --ckpt_metric val-ROUGE-1
```
How can I solve this?
Thanks in advance !!
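The failing call is `pickle.load` on a preprocessed `.source.processed` file: the pickle refers to `transformers.tokenization_utils_base`, which the installed transformers==2.11.0 evidently does not provide, suggesting the cached file was written under a different transformers version. A diagnostic sketch (not part of GAR; `referenced_globals` and `missing_modules` are hypothetical helpers) that lists the modules a pickle would import, using only `pickletools`, so the mismatch can be checked before loading:

```python
import importlib.util
import pickletools
import sys
from collections import deque
from typing import Dict, Iterable, Optional, Set, Tuple

# Opcodes whose argument is a string pushed onto the stack
_STRING_OPS = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE",
               "SHORT_BINSTRING", "BINSTRING", "STRING"}


def referenced_globals(data: bytes) -> Set[Tuple[str, str]]:
    """Return the (module, name) pairs a pickle imports when loaded, without unpickling it."""
    found: Set[Tuple[str, str]] = set()
    memo: Dict[int, Optional[str]] = {}
    recent: deque = deque(maxlen=2)  # last two pushed strings: the operands of STACK_GLOBAL
    last: Optional[str] = None       # string on top of the stack, or None for other objects
    for op, arg, _pos in pickletools.genops(data):
        name = op.name
        if name == "GLOBAL":  # protocols 0-3: arg is "module name"
            module, attr = arg.split(" ", 1)
            found.add((module, attr))
            last = None
        elif name == "STACK_GLOBAL":  # protocol 4+: module and name were pushed just before
            if len(recent) == 2 and None not in recent:
                found.add((recent[0], recent[1]))
            last = None
        elif name in _STRING_OPS:
            last = arg
            recent.append(arg)
        elif name in ("GET", "BINGET", "LONG_BINGET"):  # reuse of a memoized value
            last = memo.get(arg)
            recent.append(last)
        elif name == "MEMOIZE":  # stores the top of the stack at the next free index
            memo[len(memo)] = last
        elif name in ("PUT", "BINPUT", "LONG_BINPUT"):
            memo[arg] = last
        else:
            last = None
    return found


def missing_modules(pairs: Iterable[Tuple[str, str]]) -> Set[str]:
    """Return referenced modules that cannot be found in the current environment."""
    missing: Set[str] = set()
    for module, _name in pairs:
        if module in sys.modules:
            continue
        try:
            # Note: for a dotted name, find_spec imports the parent package first
            if importlib.util.find_spec(module) is None:
                missing.add(module)
        except ModuleNotFoundError:  # the parent package is missing as well
            missing.add(module)
    return missing
```

For example, reading one of the `*.source.processed*` files in binary mode and passing its bytes through `missing_modules(referenced_globals(...))` should list `transformers.tokenization_utils_base` in this environment. The two usual ways out are then regenerating the processed files with the installed transformers, or installing the version that wrote them.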
Thank you for sharing your implementation of GAR and RIDER.
There is a function called rider_rerank. I would appreciate it if you could explain how you actually used it in your DPR reader inference code. If possible, could you give an example of how you called it?
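For orientation while waiting for an answer: the RIDER paper's reranking rule, as I understand it, moves retrieved passages that contain any of the reader's top-N predicted answers to the front, keeping the original retrieval order within each group. A minimal sketch of that rule, assuming case-insensitive substring matching on each passage's "text" field. `rerank_by_predictions` is a hypothetical helper, not the repository's rider_rerank, whose signature is not shown here:

```python
import re
from typing import Dict, List, Sequence


def _normalize(text: str) -> str:
    # Lowercase and collapse whitespace; QA answer normalization often also strips punctuation/articles
    return re.sub(r"\s+", " ", text.lower()).strip()


def rerank_by_predictions(passages: Sequence[Dict], predictions: Sequence[str], top_n: int) -> List[Dict]:
    """Stable partition: passages containing any of the top-N predicted answers come first."""
    answers = [_normalize(p) for p in predictions[:top_n] if p.strip()]
    hits: List[Dict] = []
    misses: List[Dict] = []
    for passage in passages:
        text = _normalize(passage["text"])
        if any(answer in text for answer in answers):
            hits.append(passage)
        else:
            misses.append(passage)
    return hits + misses
```

Plain substring matching can also fire inside longer words (e.g. "paris" in "parish"); a token-level match would be stricter.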
Hey, would it be possible to provide your TriviaQA source/target files, at least for the training set? Or the process used to generate them?
Hello @morningmoni! I am reaching out about the performance of your GAR model, specifically its results when trained on the different generation targets for TriviaQA. The model performs well when the target is the answer (top-5 accuracy: 0.665), but performance drops noticeably when the target is the sentence (top-5 accuracy: 0.6221). The gap becomes more evident after fusion (sentence + answer + title), where the overall recall drops to 0.6361, even though every model was trained for 100 epochs.
Could you provide some insight into why training is considerably less effective for sentences and titles? I am also curious how the fusion results in your paper were achieved. Were any specific methods or considerations applied to improve fusion performance that might not be immediately apparent? Thank you!
In https://github.com/morningmoni/GAR/blob/master/data/download_data.py#L32, the entry with "compressed" in its name has compressed=False, and the one without it has compressed=True. They also both seem to download the same compressed file.
Just curious.