kunal-vaishnavi commented on July 18, 2024

Here's a script that exports any single decoder layer of LLaMA to an ONNX model using transformers v4.34.1. It was created by modifying existing code in this folder.

import onnx
import os
import shutil
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Model settings (change these as needed)
dtype = torch.float32
model_name = "meta-llama/Meta-Llama-3-8B"
decoder_layer_id = 4
opset_version = 14

# Folder settings (change these as needed)
cache_dir = os.path.join(".", "cache_dir")
temp_dir = os.path.join(".", "temp")
onnx_dir = os.path.join(".", "onnx")

# File settings (change these as needed)
temp_filename = os.path.join(temp_dir, "temp.onnx")
onnx_filename = os.path.join(onnx_dir, "model.onnx")

# Make OS directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(temp_dir, exist_ok=True)
os.makedirs(onnx_dir, exist_ok=True)

# Load PyTorch config and model
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
config.use_cache = True
config.torch_dtype = dtype
config._attn_implementation = "eager"
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, cache_dir=cache_dir).eval()

# Create dummy inputs for export
batch_size, sequence_length, num_heads, head_size = 2, 8, config.num_key_value_heads, config.hidden_size // config.num_attention_heads
inputs_embeds = torch.randn(batch_size, sequence_length, config.hidden_size, dtype=dtype)
attention_mask_2d = torch.ones(batch_size, sequence_length, dtype=torch.int64)
attention_mask_4d = torch.ones(batch_size, 1, sequence_length, sequence_length, dtype=torch.int64)
position_ids = attention_mask_2d.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask_2d == 0, 1)
past_key = torch.zeros(batch_size, num_heads, 0, head_size, dtype=dtype)
past_value = torch.zeros(batch_size, num_heads, 0, head_size, dtype=dtype)
output_attentions = False
use_cache = True
dummy_inputs = (inputs_embeds, attention_mask_4d, position_ids, [past_key, past_value], output_attentions, use_cache)

# Create input and output names for ONNX model (using Hugging Face naming style)
input_names = ["inputs_embeds", "attention_mask", "position_ids", f"past_key_values.{decoder_layer_id}.key", f"past_key_values.{decoder_layer_id}.value"]
output_names = ["hidden_states", f"present.{decoder_layer_id}.key", f"present.{decoder_layer_id}.value"]

# Create dynamic axes for each input and output
dynamic_axes = {}
dynamic_axes["inputs_embeds"] = {0: "batch_size", 1: "sequence_length"}
dynamic_axes["attention_mask"] = {0: "batch_size", 2: "source_sequence_length", 3: "total_sequence_length"}
dynamic_axes["position_ids"] = {0: "batch_size", 1: "sequence_length"}
dynamic_axes[f"past_key_values.{decoder_layer_id}.key"] = {0: "batch_size", 2: "past_sequence_length"}
dynamic_axes[f"past_key_values.{decoder_layer_id}.value"] = {0: "batch_size", 2: "past_sequence_length"}
dynamic_axes[f"present.{decoder_layer_id}.key"] = {0: "batch_size", 2: "total_sequence_length"}
dynamic_axes[f"present.{decoder_layer_id}.value"] = {0: "batch_size", 2: "total_sequence_length"}

# Export to ONNX
torch.onnx.export(
    model.model.layers[decoder_layer_id],
    args=dummy_inputs,
    f=temp_filename,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=opset_version,
    do_constant_folding=True,
)

# Check exported ONNX model and run shape inference (optional steps)
onnx.checker.check_model(temp_filename)
onnx.shape_inference.infer_shapes_path(temp_filename)

# Merge all external data files into one file (convenience step)
onnx_model = onnx.load_model(temp_filename, load_external_data=True)
onnx.save(
    onnx_model,
    onnx_filename,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=f"{os.path.basename(onnx_filename)}.data",
    size_threshold=0,
    convert_attribute=False,
)
shutil.rmtree(temp_dir)
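
To sanity-check the shapes the exported layer expects, here is a minimal sketch in plain Python. The helper name `layer_input_shapes` is hypothetical, and it assumes that `total_sequence_length` equals `past_sequence_length + sequence_length`. The numbers in the usage lines (hidden size 4096, 32 attention heads, 8 key/value heads) are assumed values for Llama 3 8B; in practice, read them from `config` as the script does.

```python
def layer_input_shapes(layer_id, batch_size, sequence_length, past_sequence_length,
                       hidden_size, num_attention_heads, num_key_value_heads):
    """Return the expected shape of each ONNX input, keyed by input name."""
    # Same head size formula as the export script
    head_size = hidden_size // num_attention_heads
    # Assumption: the mask and present KV cache cover past + current tokens
    total_sequence_length = past_sequence_length + sequence_length
    kv_shape = (batch_size, num_key_value_heads, past_sequence_length, head_size)
    return {
        "inputs_embeds": (batch_size, sequence_length, hidden_size),
        "attention_mask": (batch_size, 1, sequence_length, total_sequence_length),
        "position_ids": (batch_size, sequence_length),
        f"past_key_values.{layer_id}.key": kv_shape,
        f"past_key_values.{layer_id}.value": kv_shape,
    }


# Assumed Llama 3 8B values, mirroring the dummy inputs in the script
shapes = layer_input_shapes(
    layer_id=4, batch_size=2, sequence_length=8, past_sequence_length=0,
    hidden_size=4096, num_attention_heads=32, num_key_value_heads=8,
)
for name, shape in shapes.items():
    print(name, shape)
```

With an empty past (length 0), the key/value inputs have a zero-sized sequence axis, which matches the `torch.zeros(batch_size, num_heads, 0, head_size)` dummy tensors used for export.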

ONNX Runtime GenAI

If the specific decoder layer does not matter, you can use ONNX Runtime GenAI's model builder to quickly create an optimized and/or quantized ONNX model with just the first layer.

Here are some example commands that should work for your scenario.

Option 1: Build model from wheel

  1. Install ONNX Runtime GenAI

  2. Create ONNX model using the model builder

$ python3 -m onnxruntime_genai.models.builder -m meta-llama/Meta-Llama-3-8B -o ./llama3_8b -p int4 -e cpu -c ./cache_dir --extra_options num_hidden_layers=1

Option 2: Build model from source

  1. Clone ONNX Runtime GenAI and navigate to the model builder
$ git clone https://github.com/microsoft/onnxruntime-genai
$ cd onnxruntime-genai/src/python/py/models/
  2. Create ONNX model using the model builder
$ python3 builder.py -m meta-llama/Meta-Llama-3-8B -o ./llama3_8b -p int4 -e cpu -c ./cache_dir --extra_options num_hidden_layers=1

Both options build an optimized and/or quantized ONNX model that contains only the first decoder layer of the PyTorch model.
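
If you want to script the Option 1 command, here is a minimal sketch that only assembles the argument list. The helper `builder_command` is hypothetical, and it uses only the flags shown in the commands above. Passing more than one `--extra_options` entry as space-separated `key=value` pairs is an assumption, so the example sticks to the single option from this thread.

```python
import shlex


def builder_command(model_name, output_dir, precision, execution_provider,
                    cache_dir, extra_options=None, python="python3"):
    """Assemble the model builder invocation as a list of arguments."""
    cmd = [
        python, "-m", "onnxruntime_genai.models.builder",
        "-m", model_name,
        "-o", output_dir,
        "-p", precision,
        "-e", execution_provider,
        "-c", cache_dir,
    ]
    if extra_options:
        # Assumption: extra options are given as key=value pairs after the flag
        cmd.append("--extra_options")
        cmd.extend(f"{key}={value}" for key, value in extra_options.items())
    return cmd


cmd = builder_command(
    model_name="meta-llama/Meta-Llama-3-8B",
    output_dir="./llama3_8b",
    precision="int4",
    execution_provider="cpu",
    cache_dir="./cache_dir",
    extra_options={"num_hidden_layers": 1},
)
print(shlex.join(cmd))
```

To actually run it, you could pass `cmd` to `subprocess.run(cmd, check=True)` in an environment where ONNX Runtime GenAI is installed.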
