from __future__ import annotations

import os

import numpy as np
import torch
from segment_anything import sam_model_registry

import dinov2.utils.utils as dinov2_utils
from dinov2.models import vision_transformer as vits
from gradio_demo.oss_ops_inference import main_oss_ops

# SPACE_ID is set by the Hugging Face Spaces runtime; it falls back to the
# original Space id (left empty here) when the demo runs locally.
ORIGINAL_SPACE_ID = ''
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)


class Runner:
    """Holds the SAM and DINOv2 models plus per-session state for the demo."""

    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token
        # Cached OSS/OPS inputs and predictions, filled by inference_oss_ops.
        self.prompt_res_g = None   # prompt (support) image
        self.prompt_mask_g = None  # binary prompt mask, shape (1, H, W)
        self.tar1_res_g = None     # first target (query) image
        self.tar2_res_g = None     # second target (query) image
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
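        # SAM ViT-H ("default" in sam_model_registry), loaded from a local
        # checkpoint; sam_vit_h_4b8939.pth is the standard ViT-H release of
        # the segment-anything repository.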
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
model_type = "default"
self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.sam.to(device=device)
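
        # DINOv2 ViT-L/14 backbone configuration: 518-pixel inputs, 14-pixel
        # patches, matching the released dinov2_vitl14_pretrain weights.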
dinov2_kwargs = dict(
img_size=518,
patch_size=14,
init_values=1e-5,
ffn_layer='mlp',
block_chunks=0,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
)
dinov2 = vits.__dict__["vit_large"](**dinov2_kwargs)
dinov2_utils.load_pretrained_weights(dinov2, "models/dinov2_vitl14_pretrain.pth", "teacher")
dinov2.eval()
dinov2.to(device=device)
self.dinov2 = dinov2

    def inference_oss_ops(self, prompt, target1, target2, version):
        # Cache the inputs; the user-drawn prompt mask is reduced to a binary
        # (1, H, W) array before being handed to the model.
        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g = prompt['image'], target1, target2
        self.prompt_mask_g = (prompt['mask'][..., 0] != 0)[None, ...]  # (1, H, W)
        if version == 'version 1 (🔺 multiple instances 🔻 whole, 🔻 part)':
            self.version = 1
        elif version == 'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)':
            self.version = 2
        else:
            self.version = 3
        self.pred_masks, self.pred_mask_lists = main_oss_ops(
            sam=self.sam,
            dinov2=self.dinov2,
            support_img=self.prompt_res_g,
            support_mask=self.prompt_mask_g,
            query_img_1=self.tar1_res_g,
            query_img_2=self.tar2_res_g,
            version=self.version,
        )
        return "Process successful!"

    def clear_fn(self):
        # Reset all cached state and blank out the connected Gradio components
        # (one None per output slot).
        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g, self.prompt_mask_g = None, None, None, None
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None
        return [None] * 7

    @staticmethod
    def _overlay(img, mask):
        # Blend a binary mask into its RGB image as a 50/50 DodgerBlue overlay.
        color = np.array([30, 144, 255])
        h, w = mask.shape[-2:]
        mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        return img * 0.5 + mask_img * 0.5

    def controllable_mask_output(self, k):
        if self.version != 1:
            prompt_mask_res, tar1_mask_res, tar2_mask_res = self.pred_masks
        else:
            prompt_mask_res = self.pred_masks[0]
            tar1_mask_res, tar2_mask_res = self.pred_mask_lists[1:]
            # Merge up to k candidate instance masks into one binary mask.
            tar1_mask_res = tar1_mask_res[:min(k, len(tar1_mask_res))].sum(0) > 0
            tar2_mask_res = tar2_mask_res[:min(k, len(tar2_mask_res))].sum(0) > 0
        prompt_mask_res = self._overlay(self.prompt_res_g, prompt_mask_res)
        tar1_mask_res = self._overlay(self.tar1_res_g, tar1_mask_res)
        tar2_mask_res = self._overlay(self.tar2_res_g, tar2_mask_res)
        # Rescale to [0, 1] for display.
        return prompt_mask_res / 255, tar1_mask_res / 255, tar2_mask_res / 255

    def inference_vos(self, prompt_vid, vid):
        # Video object segmentation is not wired up in this demo yet.
        raise NotImplementedError
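

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The arrays below are hypothetical
# stand-ins for what the Gradio UI passes in: `prompt` is a dict holding an
# RGB 'image' array and the user-drawn 'mask' array, and `version` must match
# one of the radio-button labels. Runner() loads the SAM and DINOv2
# checkpoints from models/, so both files must be present locally.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    prompt_img = np.zeros((512, 512, 3), dtype=np.uint8)
    prompt_mask = np.zeros((512, 512, 3), dtype=np.uint8)
    prompt_mask[100:200, 100:200] = 255  # toy square annotation
    target = np.zeros((512, 512, 3), dtype=np.uint8)

    runner = Runner()
    print(runner.inference_oss_ops(
        prompt={'image': prompt_img, 'mask': prompt_mask},
        target1=target,
        target2=target,
        version='version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
    ))
    # Blend up to the first 3 predicted instance masks into each image.
    prompt_vis, tar1_vis, tar2_vis = runner.controllable_mask_output(k=3)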