Here's a script to export any single decoder layer of LLaMA to an ONNX model using transformers v4.34.1. It was created by modifying existing code in this folder.
import onnx
import os
import shutil
import torch
from transformers import AutoConfig, AutoModelForCausalLM
# Model settings (change these as needed)
dtype = torch.float32
model_name = "meta-llama/Meta-Llama-3-8B"
decoder_layer_id = 4
opset_version = 14
# Folder settings (change these as needed)
cache_dir = os.path.join(".", "cache_dir")
temp_dir = os.path.join(".", "temp")
onnx_dir = os.path.join(".", "onnx")
# File settings (change these as needed)
temp_filename = os.path.join(temp_dir, "temp.onnx")
onnx_filename = os.path.join(onnx_dir, "model.onnx")
# Make OS directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(temp_dir, exist_ok=True)
os.makedirs(onnx_dir, exist_ok=True)
# Load PyTorch config and model
config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
config.use_cache = True
config.torch_dtype = dtype
config._attn_implementation = "eager"
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, cache_dir=cache_dir).eval()
# Create dummy inputs for export
batch_size, sequence_length, num_heads, head_size = 2, 8, config.num_key_value_heads, config.hidden_size // config.num_attention_heads
inputs_embeds = torch.randn(batch_size, sequence_length, config.hidden_size, dtype=dtype)
attention_mask_2d = torch.ones(batch_size, sequence_length, dtype=torch.int64)
attention_mask_4d = torch.ones(batch_size, 1, sequence_length, sequence_length, dtype=torch.int64)
position_ids = attention_mask_2d.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask_2d == 0, 1)
past_key = torch.zeros(batch_size, num_heads, 0, head_size, dtype=dtype)
past_value = torch.zeros(batch_size, num_heads, 0, head_size, dtype=dtype)
output_attentions = False
use_cache = True
dummy_inputs = (inputs_embeds, attention_mask_4d, position_ids, [past_key, past_value], output_attentions, use_cache)
# Create input and output names for ONNX model (using Hugging Face naming style)
input_names = ["inputs_embeds", "attention_mask", "position_ids", f"past_key_values.{decoder_layer_id}.key", f"past_key_values.{decoder_layer_id}.value"]
output_names = ["hidden_states", f"present.{decoder_layer_id}.key", f"present.{decoder_layer_id}.value"]
# Create dynamic axes for each input and output
dynamic_axes = {}
dynamic_axes["inputs_embeds"] = {0: "batch_size", 1: "sequence_length"}
dynamic_axes["attention_mask"] = {0: "batch_size", 2: "source_sequence_length", 3: "total_sequence_length"}
dynamic_axes["position_ids"] = {0: "batch_size", 1: "sequence_length"}
dynamic_axes[f"past_key_values.{decoder_layer_id}.key"] = {0: "batch_size", 2: "past_sequence_length"}
dynamic_axes[f"past_key_values.{decoder_layer_id}.value"] = {0: "batch_size", 2: "past_sequence_length"}
dynamic_axes[f"present.{decoder_layer_id}.key"] = {0: "batch_size", 2: "total_sequence_length"}
dynamic_axes[f"present.{decoder_layer_id}.value"] = {0: "batch_size", 2: "total_sequence_length"}
# Export to ONNX
torch.onnx.export(
    model.model.layers[decoder_layer_id],  # export only the selected decoder layer
    args=dummy_inputs,
    f=temp_filename,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=opset_version,
    do_constant_folding=True,
)
# Check exported ONNX model and run shape inference (optional steps)
onnx.checker.check_model(temp_filename)
onnx.shape_inference.infer_shapes_path(temp_filename)
# Merge all external data files into one file (convenience step)
onnx_model = onnx.load_model(temp_filename, load_external_data=True)
onnx.save(
    onnx_model,
    onnx_filename,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=f"{os.path.basename(onnx_filename)}.data",
    size_threshold=0,
    convert_attribute=False,
)
shutil.rmtree(temp_dir)
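The dummy inputs and dynamic axes in the script follow a fixed shape bookkeeping. Here's a small pure-Python sketch (no torch needed) that reproduces it, assuming Meta-Llama-3-8B's config values (hidden_size=4096, num_attention_heads=32, num_key_value_heads=8); in the real script these come from `config`. The helper names `kv_cache_shapes` and `position_ids_from_mask` are hypothetical. It shows why the past key/value tensors use num_key_value_heads (grouped-query attention), why the present key/value grows along axis 2 to past_sequence_length + sequence_length, and what the cumsum/masked_fill step produces for a left-padded batch.

```python
from typing import List, Tuple

# Assumed Meta-Llama-3-8B config values; read them from AutoConfig in practice.
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEADS = 32
NUM_KEY_VALUE_HEADS = 8


def kv_cache_shapes(batch_size: int, past_len: int, seq_len: int) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (past key shape, present key shape) for one decoder layer (hypothetical helper)."""
    head_size = HIDDEN_SIZE // NUM_ATTENTION_HEADS  # 4096 // 32 = 128
    # With grouped-query attention the cache only stores the KV heads.
    past = (batch_size, NUM_KEY_VALUE_HEADS, past_len, head_size)
    # The layer concatenates the new keys onto the cache along axis 2.
    present = (batch_size, NUM_KEY_VALUE_HEADS, past_len + seq_len, head_size)
    return past, present


def position_ids_from_mask(mask_2d: List[List[int]]) -> List[List[int]]:
    """Mimic `mask.cumsum(-1) - 1` then `masked_fill_(mask == 0, 1)` (hypothetical helper)."""
    result = []
    for row in mask_2d:
        running, ids = 0, []
        for m in row:
            running += m
            # Real tokens count up from 0; padded positions get the filler value 1.
            ids.append(running - 1 if m != 0 else 1)
        result.append(ids)
    return result


if __name__ == "__main__":
    # First pass with an empty cache, as in the script's dummy inputs.
    print(kv_cache_shapes(batch_size=2, past_len=0, seq_len=8))
    # One full row and one left-padded row.
    print(position_ids_from_mask([[1, 1, 1, 1], [0, 0, 1, 1]]))
```

With an empty cache the past key is (2, 8, 0, 128) and the present key is (2, 8, 8, 128), which is why past_sequence_length and total_sequence_length are separate dynamic axes.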
ONNX Runtime GenAI
If the specific decoder layer does not matter, you can use ONNX Runtime GenAI's model builder to quickly create an optimized and/or quantized ONNX model with just the first layer.
Here are some example commands that should work for your scenario.
Option 1: Build model from wheel
- Install ONNX Runtime GenAI
$ pip install onnxruntime-genai
- Create ONNX model using the model builder
$ python3 -m onnxruntime_genai.models.builder -m meta-llama/Meta-Llama-3-8B -o ./llama3_8b -p int4 -e cpu -c ./cache_dir --extra_options num_hidden_layers=1
Option 2: Build model from source
- Clone ONNX Runtime GenAI and navigate to the model builder
$ git clone https://github.com/microsoft/onnxruntime-genai
$ cd onnxruntime-genai/src/python/py/models/
- Create ONNX model using the model builder
$ python3 builder.py -m meta-llama/Meta-Llama-3-8B -o ./llama3_8b -p int4 -e cpu -c ./cache_dir --extra_options num_hidden_layers=1
This approach will build an optimized and/or quantized ONNX model with the first decoder layer in the PyTorch model.
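If you'd rather drive the model builder from Python than from a shell, one option is to assemble the same argument list as the Option 1 command and hand it to subprocess. This is a minimal sketch: the helper name `builder_command` and its parameters are hypothetical, while the flags (-m, -o, -p, -e, -c, --extra_options num_hidden_layers=1) are exactly the ones used in the commands above. Actually running it assumes onnxruntime-genai is installed and you have access to the Hugging Face model.

```python
import subprocess
import sys
from typing import List


def builder_command(model: str, output_dir: str, precision: str = "int4",
                    execution_provider: str = "cpu", cache_dir: str = "./cache_dir",
                    num_hidden_layers: int = 1) -> List[str]:
    """Build the argv for the Option 1 model builder command (hypothetical helper)."""
    return [
        sys.executable, "-m", "onnxruntime_genai.models.builder",
        "-m", model,
        "-o", output_dir,
        "-p", precision,
        "-e", execution_provider,
        "-c", cache_dir,
        # Keep only the first N decoder layers in the generated ONNX model.
        "--extra_options", f"num_hidden_layers={num_hidden_layers}",
    ]


if __name__ == "__main__":
    cmd = builder_command("meta-llama/Meta-Llama-3-8B", "./llama3_8b")
    print(" ".join(cmd))
    # Uncomment to actually build the model (downloads weights, needs HF access):
    # subprocess.run(cmd, check=True)
```

Using sys.executable instead of a hard-coded python3 makes the builder run in the same environment where onnxruntime-genai was installed.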