
Comments (6)

risan-raja commented on May 27, 2024

Hi @ntnq4,
I managed to convert the SPLADE models to ONNX, although I used the pretrained checkpoint. I realize this may seem counterintuitive for your case, but I am glad if it helps.
To reproduce:

  • Convert the model to a torchscript.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer  # type: ignore

class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torchscript=True makes the model return plain tuples, which suits torch.jit.trace
        self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", torchscript=True)  # type: ignore
        self.model.eval()  # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass by keyword: the BERT forward() takes attention_mask before token_type_ids,
        # so positional arguments in this order would be swapped.
        # [0] selects the MLM logits from the returned tuple.
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]



class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()
    
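    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast(): # type: ignore
            with torch.no_grad():
                # encode() already returns the logits; the extra [0] keeps only the first
                # sequence, so this wrapper handles one input at a time.
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # SPLADE max pooling: log(1 + relu(logits)), masked, then max over the sequence
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                # Rows of (batch index, vocabulary id) for the non-zero entries
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
        # Column 1 holds the vocabulary ids (indices) and their matching weights (weights)
        return indices[:,1], weights[:,1]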

# Trace the model to TorchScript
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]))
# Save the traced model so the next step can load it (placeholder path, not from the original script)
traced_model.save("splade_traced.pt")
  • Later, load it from the file and export it to ONNX using a dummy input. Make sure to adjust the above script to match your implementation.
import torch
dyn_axis = {
    'input_ids': {0: 'batch_size', 1: 'sequence'},
    'attention_mask': {0: 'batch_size', 1: 'sequence'},
    'token_type_ids': {0: 'batch_size', 1: 'sequence'},
    'indices': {0: 'batch_size', 1: 'sequence'},
    'weights': {0: 'batch_size', 1: 'sequence'}
    }
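# model_file, dummy_input and model_onnx_file are not defined here: set them to the saved
# TorchScript path, a tuple of (input_ids, token_type_ids, attention_mask) tensors, and the
# output .onnx path.
model = torch.jit.load(model_file)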
onnx_model = torch.onnx.export(
    model,
    dummy_input, # type: ignore
    f=model_onnx_file,
    input_names=['input_ids','token_type_ids', 'attention_mask'],
    output_names=['indices', 'weights'],
    dynamic_axes=dyn_axis,
    do_constant_folding=True,
    opset_version=15,
    verbose=False,
)
  • Using this method, I have successfully converted the following HF models.
model_names = [
   "naver/splade_v2_max",
   "naver/splade_v2_distil",
   "naver/splade-cocondenser-ensembledistil",
   "naver/efficient-splade-VI-BT-large-query",
   "naver/efficient-splade-VI-BT-large-doc",
]
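
For anyone adapting the wrapper, here is a rough pure-Python sketch of what the pooling line in SpladeModel.forward computes for a single sequence: log(1 + relu(logit)) per token, multiplied by the attention mask, then a max over the sequence, keeping only the non-zero vocabulary entries. The function name splade_pool and the toy numbers are made up for illustration; this is not part of the scripts above.

```python
import math

def splade_pool(logits, attention_mask):
    """Max-pool log(1 + relu(logit)) over the unmasked token positions.

    logits: one row per token position, each a list of per-vocabulary scores.
    attention_mask: one 0/1 flag per token position.
    Returns (indices, weights) for the non-zero vocabulary entries.
    """
    vocab_size = len(logits[0])
    # All activations are >= 0, so starting the running max at 0 is safe.
    pooled = [0.0] * vocab_size
    for row, keep in zip(logits, attention_mask):
        for j, score in enumerate(row):
            activated = math.log1p(max(score, 0.0)) * keep
            pooled[j] = max(pooled[j], activated)
    indices = [j for j, w in enumerate(pooled) if w != 0.0]
    weights = [pooled[j] for j in indices]
    return indices, weights

# Toy input: 3 token positions, vocabulary of 4 entries (illustrative values only).
toy_logits = [
    [2.0, -1.0, 0.0, 0.5],
    [0.0, -3.0, 0.0, 1.5],
    [9.0, 9.0, 9.0, 9.0],  # padding position, masked out below
]
toy_mask = [1, 1, 0]
indices, weights = splade_pool(toy_logits, toy_mask)
print(indices, weights)
```

Only vocabulary entries 0 and 3 survive here: entry 1 is negative everywhere, entry 2 is zero, and the padding row is ignored because of the mask.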

requirements:

  • torch==2.2.0
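
If it is useful, the two outputs named indices and weights can be turned into a readable sparse representation. Below is a minimal sketch, assuming you already have a mapping from vocabulary id to token; the helper to_sparse_dict and the toy vocabulary are hypothetical. With the real tokenizer, something like tokenizer.convert_ids_to_tokens could provide the token strings.

```python
def to_sparse_dict(indices, weights, id_to_token, top_k=None):
    """Pair vocabulary ids with their weights, sorted by descending weight."""
    pairs = sorted(zip(indices, weights), key=lambda p: p[1], reverse=True)
    if top_k is not None:
        pairs = pairs[:top_k]
    return {id_to_token[i]: round(w, 4) for i, w in pairs}

# Toy vocabulary and outputs (illustrative values, not real model output).
toy_vocab = {0: "paris", 1: "france", 2: "capital", 3: "city"}
toy_indices = [0, 1, 3]
toy_weights = [2.1, 1.7, 0.4]
sparse = to_sparse_dict(toy_indices, toy_weights, toy_vocab, top_k=2)
print(sparse)
```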

Hope this helps! :)

sroussey commented on May 27, 2024

Example: https://huggingface.co/Xenova/t5-small-awesome-text-to-sql/tree/main/

thibault-formal commented on May 27, 2024

Hi @ntnq4

Not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?

ntnq4 commented on May 27, 2024

Hi @thibault-formal

Unfortunately, I didn't manage to make it work. I tried this tutorial, but it didn't work for my SPLADE model.

I also found this recent paper that mentions this conversion.

ntnq4 commented on May 27, 2024

Hi @risan-raja,

Thank you for your help : )
I will try your solution on my side.

sroussey commented on May 27, 2024

If an ONNX conversion were added to the Hugging Face repository in a folder called onnx, it would automatically become available to Hugging Face Transformers.js and be usable locally on the web.
