import torch
from transformers import AutoTokenizer
from llava.constants import (
    DEFAULT_VIDEO_START_TOKEN,
    DEFAULT_VIDEO_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.model import LlavaMistralForCausalLM


def load_model(model_path, device_map):
    kwargs = {"device_map": device_map}
    kwargs['torch_dtype'] = torch.float16  # Load weights in half precision

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )

    # Register the image/video delimiter tokens and grow the embedding table to match
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
        special_tokens=True
    )
    model.resize_token_embeddings(len(tokenizer))

    # Make sure the vision tower weights are actually loaded
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)

    return model, tokenizer


# Get the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the model
model, tokenizer = load_model("./masp_094_v2", device_map={"": 0})

# Extract the vision tower
vitmodel = model.get_vision_tower()
vitmodel.to(device)  # Ensure the vision tower is on the correct device

# Create a dummy input tensor for the vision tower (batch of 10 frames, 3x224x224)
dummy_input = torch.randn(10, 3, 224, 224, device=device, dtype=torch.float16)

# Export the vision tower to ONNX
onnx_path = "vit_model.onnx"
with torch.no_grad():
    torch.onnx.export(
        vitmodel,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=12,  # Use a newer opset version for better compatibility
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        verbose=True
    )

exit()
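
# Optional sanity check (a minimal sketch, not part of the original script): load the
# exported graph with onnxruntime and compare its output against the PyTorch vision
# tower. It assumes `onnxruntime` is installed, that the tower returns a single tensor,
# and that the chosen execution provider supports the float16 input; left commented out
# because the script exits above.
#
# import numpy as np
# import onnxruntime as ort
#
# session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
# ort_out = session.run(None, {"input": dummy_input.cpu().numpy()})[0]
# with torch.no_grad():
#     torch_out = vitmodel(dummy_input).cpu().numpy()
# print("max abs diff:", np.abs(ort_out - torch_out).max())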