from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys

import gradio as gr
import torch
import numpy as np
import huggingface_hub
from huggingface_hub import HfApi
from omegaconf import OmegaConf
from segment_anything import sam_model_registry

from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from gradio_demo.oss_ops_inference import main_oss_ops

ORIGINAL_SPACE_ID = ''
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)


class Runner:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token

        # self.checkpoint_dir = pathlib.Path('checkpoints')
        # self.checkpoint_dir.mkdir(exist_ok=True)

        # oss, ops state
        self.prompt_res_g = None
        self.prompt_mask_g = None
        self.tar1_res_g = None
        self.tar2_res_g = None
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # SAM (ViT-H) backbone
        sam_checkpoint = "models/sam_vit_h_4b8939.pth"
        model_type = "default"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)

        # DINOv2 ViT-L/14 feature extractor
        dinov2_kwargs = dict(
            img_size=518,
            patch_size=14,
            init_values=1e-5,
            ffn_layer='mlp',
            block_chunks=0,
            qkv_bias=True,
            proj_bias=True,
            ffn_bias=True,
        )
        dinov2 = vits.__dict__["vit_large"](**dinov2_kwargs)
        dinov2_utils.load_pretrained_weights(dinov2, "models/dinov2_vitl14_pretrain.pth", "teacher")
        dinov2.eval()
        dinov2.to(device=device)
        self.dinov2 = dinov2

    def inference_oss_ops(self, prompt, target1, target2, version):
        # Cache the prompt/target images and the binary prompt mask for later visualization.
        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g = prompt['image'], target1, target2
        self.prompt_mask_g = (prompt['mask'][..., 0] != 0)[None, ...]  # 1, H, W

        if version == 'version 1 (🔺 multiple instances 🔻 whole, 🔻 part)':
            self.version = 1
        elif version == 'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)':
            self.version = 2
        else:
            self.version = 3

        self.pred_masks, self.pred_mask_lists = main_oss_ops(
            sam=self.sam,
            dinov2=self.dinov2,
            support_img=self.prompt_res_g,
            support_mask=self.prompt_mask_g,
            query_img_1=self.tar1_res_g,
            query_img_2=self.tar2_res_g,
            version=self.version,
        )
        text = "Process Successful!"
        return text

    def clear_fn(self):
        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g, self.prompt_mask_g = None, None, None, None
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None
        return [None] * 7

    def controllable_mask_output(self, k):
        color = np.array([30, 144, 255])
        if self.version != 1:
            prompt_mask_res, tar1_mask_res, tar2_mask_res = self.pred_masks
        else:
            prompt_mask_res = self.pred_masks[0]
            tar1_mask_res, tar2_mask_res = self.pred_mask_lists[1:]
            # keep at most the top-k predicted instances per target image
            tar1_mask_res = tar1_mask_res[:min(k, len(tar1_mask_res))].sum(0) > 0
            tar2_mask_res = tar2_mask_res[:min(k, len(tar2_mask_res))].sum(0) > 0

        # Blend each binary mask with its image as a 50/50 colored overlay.
        h, w = prompt_mask_res.shape[-2:]
        prompt_mask_img = prompt_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        prompt_mask_res = self.prompt_res_g * 0.5 + prompt_mask_img * 0.5

        h, w = tar1_mask_res.shape[-2:]
        tar1_mask_img = tar1_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        tar1_mask_res = self.tar1_res_g * 0.5 + tar1_mask_img * 0.5

        h, w = tar2_mask_res.shape[-2:]
        tar2_mask_img = tar2_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        tar2_mask_res = self.tar2_res_g * 0.5 + tar2_mask_img * 0.5

        return prompt_mask_res / 255, tar1_mask_res / 255, tar2_mask_res / 255

    def inference_vos(self, prompt_vid, vid):
        raise NotImplementedError
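

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module logic).
#
# The Blocks layout below is an assumption about how Runner might be wired
# into a Gradio demo: component labels, the third radio label, the HF_TOKEN
# environment variable, and the launch options are illustrative, not taken
# from this file. It assumes Gradio 3.x, where gr.Image(tool='sketch') passes
# a dict with 'image' and 'mask' keys, matching how inference_oss_ops reads
# the prompt. The seven outputs cleared by clear_fn are assumed to be the
# three input images, the three overlays, and the status textbox.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    runner = Runner(hf_token=os.getenv('HF_TOKEN'))

    with gr.Blocks() as demo:
        with gr.Row():
            prompt_img = gr.Image(label='Prompt (draw the mask)', tool='sketch', type='numpy')
            target1_img = gr.Image(label='Target image 1', type='numpy')
            target2_img = gr.Image(label='Target image 2', type='numpy')
        version = gr.Radio(
            choices=[
                'version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
                'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)',
                'version 3',  # hypothetical label: any other choice maps to version 3
            ],
            value='version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
            label='Version',
        )
        status = gr.Textbox(label='Status', interactive=False)
        run_btn = gr.Button('Run')
        with gr.Row():
            prompt_out = gr.Image(label='Prompt overlay')
            target1_out = gr.Image(label='Target 1 overlay')
            target2_out = gr.Image(label='Target 2 overlay')
        k_slider = gr.Slider(1, 10, value=3, step=1, label='Top-k instances (version 1 only)')
        show_btn = gr.Button('Show masks')
        clear_btn = gr.Button('Clear')

        run_btn.click(
            runner.inference_oss_ops,
            inputs=[prompt_img, target1_img, target2_img, version],
            outputs=status,
        )
        show_btn.click(
            lambda k: runner.controllable_mask_output(int(k)),  # slider values may arrive as floats
            inputs=k_slider,
            outputs=[prompt_out, target1_out, target2_out],
        )
        clear_btn.click(
            runner.clear_fn,
            inputs=None,
            outputs=[prompt_img, target1_img, target2_img, prompt_out, target1_out, target2_out, status],
        )

    demo.launch()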