
Comments (7)

abhinavkulkarni commented on June 4, 2024

@JLUGQQ: Here's what I have in git diff. Let me know if this helps.

diff --git a/blink/biencoder/nn_prediction.py b/blink/biencoder/nn_prediction.py
index eab90a8..18e50cd 100644
--- a/blink/biencoder/nn_prediction.py
+++ b/blink/biencoder/nn_prediction.py
@@ -55,13 +55,20 @@ def get_topk_predictions(
     oid = 0
     for step, batch in enumerate(iter_):
         batch = tuple(t.to(device) for t in batch)
-        context_input, _, srcs, label_ids = batch
+        if is_zeshel:
+            context_input, _, srcs, label_ids = batch
+        else:
+            context_input, _, label_ids = batch
+            srcs = torch.tensor([0] * context_input.size(0), device=device)
+
         src = srcs[0].item()
+        cand_encode_list[src] = cand_encode_list[src].to(device)
         scores = reranker.score_candidate(
             context_input, 
             None, 
-            cand_encs=cand_encode_list[src].to(device)
+            cand_encs=cand_encode_list[src]
         )
+
         values, indicies = scores.topk(top_k)
         old_src = src
         for i in range(context_input.size(0)):
@@ -93,7 +100,7 @@ def get_topk_predictions(
                 continue
 
             # add examples in new_data
-            cur_candidates = candidate_pool[src][inds]
+            cur_candidates = candidate_pool[srcs[i].item()][inds]
             nn_context.append(context_input[i].cpu().tolist())
             nn_candidates.append(cur_candidates.cpu().tolist())
             nn_labels.append(pointer)
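A minimal sketch of the caching pattern the diff introduces (the shapes, the `"cpu"` device, and the random tensors are hypothetical, not from the real zeshel data): each world's candidate encodings are moved to the device once and the device-resident tensor is stored back into the dict, instead of calling `.to(device)` on every batch.

```python
import torch

# Hypothetical stand-in for cand_encode_list: world id -> candidate encodings.
device = torch.device("cpu")  # "cuda" in the real setting
cand_encode_list = {0: torch.randn(4, 8), 1: torch.randn(4, 8)}

src = 0
# Move once and cache; later batches with the same src reuse the cached tensor.
cand_encode_list[src] = cand_encode_list[src].to(device)

# Score a (hypothetical) batch of 2 context vectors against the 4 candidates.
context_vecs = torch.randn(2, 8)
scores = context_vecs.mm(cand_encode_list[src].t())  # shape (2, 4)
values, indices = scores.topk(2)
```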


abhinavkulkarni commented on June 4, 2024

@JLUGQQ: Yes, this is a bug, you can look at issue #95 for the solution.


JLUGQQ commented on June 4, 2024

#95

Thank you. I have tried this solution before, but it didn't work. Maybe I should change my package versions according to requirements.txt.


abhinavkulkarni commented on June 4, 2024

@JLUGQQ: I am able to successfully run eval on both zeshel and non-zeshel datasets. Feel free to copy and paste your error message here; I'd be glad to take a look.


JLUGQQ commented on June 4, 2024

#95

Thank you very much for your help!
I could successfully run train_biencoder, but when I ran eval_biencoder I encountered this problem. I have changed the code according to issue #95:
05/06/2022 13:33:00 - INFO - Blink - Getting top 64 predictions.
0%| | 0/2500 [00:00<?, ?it/s]05/06/2022 13:33:00 - INFO - Blink - World size : 16
0%| | 0/2500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 337, in <module>
    main(new_params)
  File "/data/gavin/BLINK-main/blink/biencoder/eval_biencoder.py", line 289, in main
    save_results,
  File "/data/gavin/BLINK-main/blink/biencoder/nn_prediction.py", line 65, in get_topk_predictions
    cand_encs=cand_encode_list[src].to(device)
KeyError: 12


JLUGQQ commented on June 4, 2024

Pity, it still doesn't work. Thanks for your reply. I think I should take some time to debug and find the exact reason, and I will comment if I solve this problem.


yc-song commented on June 4, 2024

The KeyError can happen because the validation or test set tries to find its encodings among the training-set encodings. For example, validation data with src value 9 crashes when it looks up its encodings in the training encodings, whose src values run from 0 to 8; that is why there is a KeyError for value 9.
Although there may be multiple ways to fix this, I recommend saving each split's encodings in a separate file, i.e. the following shell commands worked in my case:

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode train --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_train.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_train.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode valid --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_valid.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_valid.pt

python blink/biencoder/eval_biencoder.py --path_to_model models/zeshel/biencoder/pytorch_model.bin --data_path data/zeshel/blink_format --output_path models/zeshel --encode_batch_size 128 --eval_batch_size 1 --top_k 64 --save_topk_result --bert_model bert-large-uncased --mode test --zeshel True --data_parallel --cand_encode_path data/zeshel/cand_enc/cand_enc_test.pt --cand_pool_path data/zeshel/cand_pool/cand_pool_test.pt
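The failure mode described above can be reproduced in miniature. The sketch below uses hypothetical world ids (training worlds 0 to 8, and a test world 12 as in the traceback), not the real zeshel splits:

```python
# Encodings were built from the training worlds only, so the dict
# has keys 0..8; a test-set batch then arrives with src 12.
cand_encode_list = {src: f"encodings_for_world_{src}" for src in range(9)}

try:
    cand_encode_list[12]  # lookup for a world that was never encoded
except KeyError as err:
    missing_src = err.args[0]  # 12, matching the KeyError in the traceback
```

Running eval_biencoder once per split, each with its own --cand_encode_path and --cand_pool_path as in the commands above, keeps each split's src values consistent with the encodings file it loads.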

