# Source: model1 / onnx_convert.py
# Author: multitensor
# Commit bbfa6f6 (verified): "Upload folder using huggingface_hub"
import sys
import torch
import os
import random
import base64
import msgpack
from io import BytesIO
import numpy as np
from transformers import AutoTokenizer
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2
from llava.model.builder import load_pretrained_model
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor
from llava.model import LlavaMistralForCausalLM
def load_model(model_path, device_map, torch_dtype=torch.float16):
    """Load the LLaVA-Mistral checkpoint, its tokenizer, and its vision tower.

    Args:
        model_path: Checkpoint location passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``LlavaMistralForCausalLM.from_pretrained``.
        device_map: ``device_map`` forwarded to ``from_pretrained`` and to the
            vision tower's ``load_model`` (e.g. ``{"": 0}``).
        torch_dtype: dtype requested for the language-model weights. Defaults
            to ``torch.float16``, the value that was previously hard-coded.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer has the image/video
        start/end marker tokens added as special tokens, and the model's token
        embeddings are resized to ``len(tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
    )

    # Register the multimodal marker tokens, then resize the embedding matrix
    # so every token id in the tokenizer has a corresponding row.
    tokenizer.add_tokens(
        [
            DEFAULT_IM_START_TOKEN,
            DEFAULT_IM_END_TOKEN,
            DEFAULT_VIDEO_START_TOKEN,
            DEFAULT_VIDEO_END_TOKEN,
        ],
        special_tokens=True,
    )
    model.resize_token_embeddings(len(tokenizer))

    # The vision tower reports whether its weights are loaded; load them on
    # the same device map if they are not.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)
    return model, tokenizer
def main(argv=None):
    """Export the checkpoint's vision tower to an ONNX file.

    Every command-line argument is optional. The defaults reproduce the
    values that were previously hard-coded:
        model_path    checkpoint directory (default ``./masp_094_v2``)
        --output      ONNX file to write (default ``vit_model.onnx``)
        --batch-size  batch dimension of the dummy trace input (default 10)
        --image-size  square spatial size of the dummy input (default 224)
        --opset       ONNX opset version (default 12)
        --quiet       disable the verbose graph dump during export
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Export the LLaVA vision tower to ONNX."
    )
    parser.add_argument("model_path", nargs="?", default="./masp_094_v2")
    parser.add_argument("--output", default="vit_model.onnx")
    parser.add_argument("--batch-size", type=int, default=10)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--opset", type=int, default=12)
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Do not print the exported graph.",
    )
    args = parser.parse_args(argv)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Keep weight placement consistent with `device`. A hard-coded {"": 0}
    # always targets GPU 0 and fails on CPU-only machines, even though
    # `device` itself already falls back to CPU.
    device_map = {"": 0} if device.type == "cuda" else {"": "cpu"}
    model, _tokenizer = load_model(args.model_path, device_map=device_map)

    vitmodel = model.get_vision_tower()
    vitmodel.to(device)
    if device.type == "cpu":
        # Half-precision kernel coverage on CPU is limited, so trace in float32.
        vitmodel.float()
    # torch.onnx.export defaults to eval-mode export; set it explicitly anyway.
    vitmodel.eval()

    # Match the dummy input's dtype to the vision tower's weights. The tower
    # may be loaded with a different dtype than the language model.
    first_param = next(vitmodel.parameters(), None)
    dtype = first_param.dtype if first_param is not None else torch.float16
    # NOTE(review): (batch, 3, H, W) with H=W=224 assumes RGB 224px input;
    # confirm against the checkpoint's image processor configuration.
    dummy_input = torch.randn(
        args.batch_size,
        3,
        args.image_size,
        args.image_size,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        torch.onnx.export(
            vitmodel,
            dummy_input,
            args.output,
            export_params=True,
            # Opset 12 is a conservative choice; raise it if the graph needs
            # operators introduced in later opsets.
            opset_version=args.opset,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
            verbose=not args.quiet,
        )


if __name__ == "__main__":
    main()