Running Demo.py succeeds, but exporting the model fails with a CUDA out-of-memory (OOM) error.
Device: A10 x 2
command:
OMP_NUM_THREADS=1 torchrun --nproc_per_node 2 Export.py --ckpt_dir /data/workspace/CodeLlama-13b --tokenizer_path /data/workspace/CodeLlama-13b/tokenizer.model --export_path /data/workspace/13_v2 --fused_qkv 1 --fused_kvcache 1 --auto_causal 1 --quantized_cache 1 --dynamic_batching 1
File "/data/workspace/ppl/ppl.pmx/model_zoo/llama/facebook/Export.py", line 44, in <module>
fire.Fire(main)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/data/workspace/ppl/ppl.pmx/model_zoo/llama/facebook/Export.py", line 40, in main
generator.export(export_path)
File "/data/workspace/ppl/ppl.pmx/model_zoo/llama/facebook/../../llama/modeling/dynamic_batching/Pipeline.py", line 291, in export
torch.onnx.export(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
_export(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py", line 114, in wrapper
tuple(x.clone(memory_format=torch.preserve_format) for x in args)
File "/root/miniconda3/envs/test/lib/python3.10/site-packages/torch/jit/_trace.py", line 114, in <genexpr>
tuple(x.clone(memory_format=torch.preserve_format) for x in args)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 68.00 MiB (GPU 1; 22.20 GiB total capacity; 21.58 GiB already allocated; 12.12 MiB free; 21.58 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
I have tried the settings below, but the OOM error still occurs:
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:32
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:64