|
import sys |
|
import torch |
|
import os |
|
import random |
|
import base64 |
|
import msgpack |
|
from io import BytesIO |
|
import numpy as np |
|
|
|
from transformers import AutoTokenizer |
|
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
|
from llava.conversation import conv_templates, SeparatorStyle |
|
from llava.utils import disable_torch_init |
|
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2 |
|
from llava.model.builder import load_pretrained_model |
|
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
|
from llava.model import LlavaMistralForCausalLM |
|
|
|
def load_model(model_path, device_map):
    """Load a LLaVA-Mistral checkpoint together with its tokenizer.

    Loads the causal LM in float16 with the given HF ``device_map``,
    registers the image/video boundary markers as special tokens, resizes
    the embedding matrix accordingly, and ensures the vision tower weights
    are materialised.

    Args:
        model_path: Path (or hub id) of the pretrained checkpoint.
        device_map: HF-style device map passed to ``from_pretrained`` and
            to the vision tower loader.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        device_map=device_map,
        torch_dtype=torch.float16,
    )

    # Add the multimodal boundary markers, then grow the embedding table
    # so the new token ids have rows to look up.
    boundary_tokens = [
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN,
        DEFAULT_VIDEO_END_TOKEN,
    ]
    tokenizer.add_tokens(boundary_tokens, special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # The vision encoder may be lazily initialised; load its weights now.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)

    return model, tokenizer
|
|
|
|
|
# ---------------------------------------------------------------------------
# Script body: export the vision tower (ViT) of a LLaVA-Mistral checkpoint
# to ONNX with a dynamic batch dimension.
# ---------------------------------------------------------------------------

# Pick the compute device, falling back to CPU when no GPU is available.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Keep the HF device_map consistent with the chosen device. The original
# hard-coded {"": 0} (GPU 0), which fails on CPU-only hosts even though
# `device` correctly fell back to CPU.
device_map = {"": 0} if use_cuda else {"": "cpu"}

model, tokenizer = load_model("./masp_094_v2", device_map=device_map)

# Only the vision tower is exported; move it to the target device.
vitmodel = model.get_vision_tower()
vitmodel.to(device)

# Dummy batch of 10 RGB 224x224 frames; fp16 matches the loaded weights.
# NOTE(review): many CPU ops do not support float16 — if CPU export is a
# real use case, confirm it works or cast tower + input to float32 there.
dummy_input = torch.randn(10, 3, 224, 224, device=device, dtype=torch.float16)

onnx_path = "vit_model.onnx"
with torch.no_grad():
    torch.onnx.export(
        vitmodel,
        dummy_input,
        onnx_path,
        export_params=True,        # embed the weights in the .onnx file
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        # Allow arbitrary batch sizes at inference time.
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        verbose=True
    )

# Bare exit() is injected by the `site` module and is not guaranteed to be
# present in all runtimes; sys.exit() is the portable way to terminate.
sys.exit(0)