Comments (7)
@JLUGQQ: Here's what I have in git diff. Let me know if this helps.
diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
oid = 0
for step, batch in enumerate(iter_):
batch = tuple(t.to(device) for t in batch)
- context_input, _, srcs, label_ids = batch
+ if is_zeshel:
+ context_input, _, srcs, label_ids = batch
+ else:
+ context_input, _, label_ids = batch
+ srcs = torch.tensor([0] * context_input.size(0), device=device)
+
src = srcs[0].item()
+ cand_encode_list[src] = cand_encode_list[src].to(device)
scores = reranker.score_candidate(
context_input,
None,
- cand_encs=cand_encode_list[src].to(device)
+ cand_encs=cand_encode_list[src]
)
+
values, indicies = scores.topk(top_k)
old_src = src
for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
continue
# add examples in new_data
- cur_candidates = candidate_pool[src][inds]
+ cur_candidates = candidate_pool[srcs[i].item()][inds]
nn_context.append(context_input[i].cpu().tolist())
nn_candidates.append(cur_candidates.cpu().tolist())
nn_labels.append(pointer)
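In words, the patch above does two things: it unpacks the batch differently depending on whether the dataset is zeshel (only zeshel batches carry a per-example world id, srcs), and it synthesizes a dummy srcs tensor for non-zeshel data. A minimal sketch of that unpacking logic, using plain Python lists in place of torch tensors (is_zeshel and the batch layout follow the diff; the values are illustrative):

```python
# Sketch of the patched batch unpacking in get_topk_predictions.
# Zeshel batches are (context, mask, srcs, labels); non-zeshel batches
# are (context, mask, labels), so srcs must be synthesized as world 0.
def unpack_batch(batch, is_zeshel):
    if is_zeshel:
        context_input, _, srcs, label_ids = batch
    else:
        context_input, _, label_ids = batch
        srcs = [0] * len(context_input)  # single "world" 0 for non-zeshel data
    return context_input, srcs, label_ids

# zeshel-style batch: srcs comes from the data itself
ctx, srcs, labels = unpack_batch(([1, 2], None, [3, 3], [7, 8]), is_zeshel=True)
# non-zeshel batch: srcs is synthesized, one 0 per example
ctx2, srcs2, labels2 = unpack_batch(([1, 2], None, [7, 8]), is_zeshel=False)
```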
@JLUGQQ: Yes, this is a bug, you can look at issue #95 for the solution.
Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.
@JLUGQQ: I am able to successfully run both eval on both zeshel and non-zeshel datasets. Feel free to copy and paste your error message here, I'd be glad to take a look.
Thank you very much for your help!
I could successfully run train_biencoder, but when I ran eval_biencoder I encountered the problem below, even after changing the code according to issue #95:
05/06/2022 13:33:00 - INFO - Blink - Getting top 64 predictions.
05/06/2022 13:33:00 - INFO - Blink - World size : 16
0%| | 0/2500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 337, in <module>
main(new_params)
File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 289, in main
save_results,
File "/data/gavin/BLINK-main/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
cand_encs=cand_encode_list[src].to(device)
KeyError: 12
Unfortunately, it still doesn't work. Thanks for your reply. I think I should take some time to debug and find the exact cause, and I will comment here if I solve the problem.
The KeyError can happen because the validation or test set tries to look up its encodings in the training-set encodings. For example, the run crashes when validation data with src value 9 attempts to find its encodings among the training encodings, which only cover src values 0 to 8; that is why there is a KeyError for value 9.
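The failure mode described above can be reproduced with a tiny sketch (the world ids are illustrative; cand_encode_list mirrors the dict of per-world encodings built in nn_prediction.py):

```python
# Candidate encodings were computed for the training worlds only,
# so the dict is keyed by src values 0..8 (illustrative ids).
cand_encode_list = {src: f"encodings_for_world_{src}" for src in range(9)}

val_src = 9  # a validation batch arrives with a world id outside the dict
try:
    cand_encode_list[val_src]  # same lookup as cand_encode_list[src] above
    crashed = False
except KeyError:
    crashed = True  # the KeyError seen in the traceback above

print("lookup crashed:", crashed)
```

Saving per-split encodings, as recommended below, ensures each split only ever looks up world ids that exist in its own dict.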
Although there might be multiple ways to fix this, I recommend saving each split's encodings in a separate file. The following shell commands worked in my case:
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode train --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_train.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_train.pt
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode valid --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_valid.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_valid.pt
python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode test --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_test.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_test.pt