Comments (6)
Hi @ntnq4,
I managed to convert the SPLADE models to ONNX, although I used the pretrained checkpoint. I am aware this may not be exactly what you need, but I hope it helps.
To reproduce:
- Convert the model to TorchScript.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer  # type: ignore

class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torchscript=True makes the model return plain tuples, which tracing needs
        self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", torchscript=True)  # type: ignore
        self.model.eval()  # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass by keyword: the BERT forward signature takes attention_mask before token_type_ids
        # Element 0 of the output tuple holds the MLM logits (batch, seq_len, vocab_size)
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]

class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # type: ignore
            with torch.no_grad():
                # encode() already returns the logits, so no further indexing is needed
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                # The sparse extraction below assumes a batch of one
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
        return indices[:, 1], weights[:, 1]

# Convert the model to TorchScript
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]))
# Save it for the next step with traced_model.save(model_file), where model_file is a path of your choice
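The pooling step inside forward can be checked in isolation, without downloading a model. This is only a minimal sketch on hand-written toy logits: `splade_pool` and `to_sparse` are hypothetical helper names, not part of the SPLADE repository, and the tensors are made up for illustration.

```python
import torch


def splade_pool(lm_logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Max-pool log(1 + relu(logits)) over the sequence, ignoring padding.

    lm_logits:      (batch, seq_len, vocab_size) float tensor
    attention_mask: (batch, seq_len) tensor of 0/1
    Returns a (batch, vocab_size) tensor of non-negative term weights.
    """
    activated = torch.log(1 + torch.relu(lm_logits))
    # Zero out padded positions so they can never win the max
    masked = activated * attention_mask.unsqueeze(-1)
    vec, _ = torch.max(masked, dim=1)
    return vec


def to_sparse(vec_row: torch.Tensor):
    """Turn one dense (vocab_size,) row into (token indices, weights)."""
    indices = vec_row.nonzero().squeeze(-1)
    return indices, vec_row[indices]


# Toy input: one sequence, three positions (the last is padding), vocabulary of five
logits = torch.tensor([[[1.0, -2.0, 0.0, 0.5, -1.0],
                        [0.2, 3.0, -1.0, 0.0, -0.5],
                        [9.0, 9.0, 9.0, 9.0, 9.0]]])
mask = torch.tensor([[1, 1, 0]])

vec = splade_pool(logits, mask)
indices, weights = to_sparse(vec[0])
print(indices.tolist(), weights.tolist())
```

The padded position carries large logits on purpose: after masking it contributes nothing, so only vocabulary entries 0, 1 and 3 end up with non-zero weights.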
- Later, load the traced model from file and export it to ONNX using a dummy input. Make sure to adjust the script above to match your implementation.
import torch
dyn_axis = {
    'input_ids': {0: 'batch_size', 1: 'sequence'},
    'attention_mask': {0: 'batch_size', 1: 'sequence'},
    'token_type_ids': {0: 'batch_size', 1: 'sequence'},
    'indices': {0: 'batch_size', 1: 'sequence'},
    'weights': {0: 'batch_size', 1: 'sequence'}
}
model = torch.jit.load(model_file)
onnx_model = torch.onnx.export(
    model,
    dummy_input,  # type: ignore
    f=model_onnx_file,
    input_names=['input_ids', 'token_type_ids', 'attention_mask'],
    output_names=['indices', 'weights'],
    dynamic_axes=dyn_axis,
    do_constant_folding=True,
    opset_version=15,
    verbose=False,
)
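The export call above uses a `dummy_input` that the snippet does not define. One way to build it, as a rough sketch: `make_dummy_input` is a hypothetical helper, and the token id 1 and sequence length 8 are arbitrary placeholder values chosen for illustration. The tokenizer output from the tracing step should work just as well.

```python
import torch


def make_dummy_input(seq_len: int = 8):
    """Build a (input_ids, token_type_ids, attention_mask) tuple for a batch of one.

    The order matches input_names in the export call and the forward() signature
    of the traced model. All tensors are int64, as the tokenizer would produce.
    """
    input_ids = torch.ones((1, seq_len), dtype=torch.long)        # placeholder token id 1
    token_type_ids = torch.zeros((1, seq_len), dtype=torch.long)  # single segment
    attention_mask = torch.ones((1, seq_len), dtype=torch.long)   # no padding
    return input_ids, token_type_ids, attention_mask


dummy_input = make_dummy_input()
```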
- Using this method I have managed to convert the following HF models successfully.
model_names = [
    "naver/splade_v2_max",
    "naver/splade_v2_distil",
    "naver/splade-cocondenser-ensembledistil",
    "naver/efficient-splade-VI-BT-large-query",
    "naver/efficient-splade-VI-BT-large-doc",
]
requirements:
torch==2.2.0
Hope this helps! :)
from splade.
Example: https://huggingface.co/Xenova/t5-small-awesome-text-to-sql/tree/main/
Hi @ntnq4
Not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?
Unfortunately, I didn't manage to make it work. I tried this tutorial, but it didn't work for my SPLADE model.
I also found this recent paper that mentioned this conversion.
Hi @risan-raja,
Thank you for your help : )
I will try your solution on my side.
If an ONNX conversion were added to the HuggingFace model repository in a folder called onnx, it would automatically become available to HuggingFace Transformers.js and be usable locally on the web.