|
import sys |
|
import torch |
|
import os |
|
import random |
|
import base64 |
|
import msgpack |
|
from io import BytesIO |
|
import numpy as np |
|
|
|
from transformers import AutoTokenizer |
|
from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN, DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
|
from llava.conversation import conv_templates, SeparatorStyle |
|
from llava.utils import disable_torch_init |
|
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_images_v2 |
|
from llava.model.builder import load_pretrained_model |
|
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
|
from llava.model import LlavaMistralForCausalLM |
|
|
|
|
|
from transformers import CLIPImageProcessor |
|
from PIL import Image |
|
import logging |
|
|
|
def select_frames(input_frames, num_segments = 10): |
|
|
|
indices = np.linspace(start=0, stop=len(input_frames)-1, num=num_segments).astype(int) |
|
|
|
frames = [input_frames[ind] for ind in indices] |
|
|
|
return frames |
|
|
|
def load_model(model_path, device_map): |
|
kwargs = {"device_map": device_map} |
|
kwargs['torch_dtype'] = torch.float32 |
|
tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
model = LlavaMistralForCausalLM.from_pretrained( |
|
model_path, |
|
low_cpu_mem_usage=True, |
|
**kwargs |
|
) |
|
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN], special_tokens=True) |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
vision_tower = model.get_vision_tower() |
|
if not vision_tower.is_loaded: |
|
vision_tower.load_model(device_map=device_map) |
|
|
|
return model, tokenizer |
|
|
|
|
|
class EndpointHandler: |
|
|
|
def __init__(self): |
|
model_path = './checkpoint-3000' |
|
disable_torch_init() |
|
model_path = os.path.expanduser(model_path) |
|
|
|
model_name = get_model_name_from_path(model_path) |
|
|
|
model, tokenizer = load_model(model_path, device_map={"":0}) |
|
|
|
|
|
image_processor = Blip2ImageTrainProcessor( |
|
image_size=model.config.img_size, |
|
is_training=False) |
|
|
|
""" |
|
import os |
|
from PIL import Image |
|
input_dir = './v12044gd0000clg1n4fog65p7pag5n6g/video' |
|
image_paths = os.listdir(input_dir) |
|
images = [Image.open(os.path.join(input_dir, item)) for item in image_paths] |
|
num_segments = 10 |
|
images = images[:num_segments] |
|
|
|
import torch |
|
device = torch.device('cuda:0') |
|
image_processor = Blip2ImageTrainProcessor( |
|
image_size=224, |
|
is_training=False) |
|
images_tensor = [image_processor.preprocess(image).cpu().to(device) for image in images] |
|
""" |
|
|
|
self.tokenizer = tokenizer |
|
self.device = torch.device('cpu') |
|
self.model = model.to(self.device) |
|
|
|
self.image_processor = image_processor |
|
self.conv_mode = 'v1' |
|
|
|
def inference_frames(self, images, question, temperature): |
|
|
|
if len(images) > 10: |
|
images = select_frames(images) |
|
|
|
conv_mode = self.conv_mode |
|
image_processor = self.image_processor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
images_tensor = process_images_v2(images, image_processor, self.model.config) |
|
images_tensor = images_tensor.to(self.device) |
|
|
|
|
|
qs = question |
|
|
|
if len(images) == 1: |
|
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs |
|
else: |
|
qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs |
|
|
|
conv = conv_templates[conv_mode].copy() |
|
conv.append_message(conv.roles[0], qs) |
|
conv.append_message(conv.roles[1], None) |
|
prompt = conv.get_prompt() |
|
|
|
input_ids = tokenizer_image_token(prompt, self.tokenizer, MM_TOKEN_INDEX, return_tensors='pt').unsqueeze( |
|
0).to(self.device) |
|
|
|
stop_str = conv.sep if conv.sep2 is None else conv.sep2 |
|
keywords = [stop_str] |
|
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) |
|
|
|
with torch.inference_mode(): |
|
output_ids = self.model.generate( |
|
input_ids, |
|
images=[images_tensor], |
|
temperature=temperature, |
|
do_sample=True, |
|
top_p=None, |
|
num_beams=1, |
|
no_repeat_ngram_size=3, |
|
max_new_tokens=1024, |
|
use_cache=True, |
|
stopping_criteria=[stopping_criteria], |
|
) |
|
|
|
|
|
outputs = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() |
|
|
|
outputs = outputs.strip() |
|
if outputs.endswith(conv.sep): |
|
outputs = outputs[:-len(stop_str)] |
|
outputs = outputs.strip() |
|
|
|
|
|
|
|
|
|
return outputs |
|
|
|
def __call__(self, request): |
|
|
|
|
|
packed_data= request['images'][0] |
|
unpacked_data = msgpack.unpackb(packed_data, raw=False) |
|
image_list = [Image.open(BytesIO(byte_data)) for byte_data in unpacked_data] |
|
prompt = request.get('prompt', [''.encode()])[0].decode() |
|
temperature = request.get('temperature', ['0.01'.encode()])[0].decode() |
|
temperature = float(temperature) |
|
|
|
|
|
|
|
if prompt=='': |
|
if len(image_list) == 1: |
|
prompt = "Please describe this image in detail." |
|
else: |
|
prompt = "Please describe this video in detail." |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = self.inference_frames(image_list, prompt, temperature) |
|
|
|
|
|
return {'output': [outputs]} |
|
|
|
|
|
if __name__ == "__main__": |
|
video_dir = '/mnt/bn/yukunfeng-nasdrive/xiangchen/masp_data/20231110_ttp/video/v12044gd0000cl5c6rfog65i2eoqcqig' |
|
frames = [(int(os.path.splitext(item)[0]), os.path.join(video_dir, item)) for item in os.listdir(video_dir)] |
|
frames = [item[1] for item in sorted(frames, key=lambda x: x[0])] |
|
out_frames = [Image.open(frame).convert('RGB') for frame in frames] |
|
|
|
|
|
|
|
request = {} |
|
|
|
|
|
byte_images = [] |
|
for img in out_frames: |
|
byte_io = BytesIO() |
|
img.save(byte_io, format='JPEG') |
|
byte_images.append(byte_io.getvalue()) |
|
|
|
|
|
packed_data = msgpack.packb(byte_images) |
|
request['images'] = [packed_data] |
|
|
|
request['temperature'] = ['0.01'.encode()] |
|
|
|
|
|
|
|
|
|
handler = EndpointHandler() |
|
print(handler(request)) |
|
|
|
|