import os
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from gradio.utils import get_cache_folder
from gradio_client import Client, handle_file
from huggingface_hub import hf_hub_download  # needed by download_additional_params below
from PIL import Image


class Examples(gr.helpers.Examples):
    """gr.Examples subclass that allows reusing a pre-built cache folder."""

    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
        # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


def postprocess(output, prompt):
    """Split the horizontally concatenated result image into one (tile, caption) pair per task."""
    result = []
    image = Image.open(output)
    w, h = image.size
    n = len(prompt)
    slice_width = w // n
    for i in range(n):
        left = i * slice_width
        # the last tile absorbs any rounding remainder
        right = (i + 1) * slice_width if i < n - 1 else w
        cropped_img = image.crop((left, 0, right, h))
        # use the task name as the caption and collect the pair
        caption = prompt[i]
        result.append((cropped_img, caption))
    return result


# The user clicks the image to select points; draw the selected points on the image.
def get_point(img, sel_pix, evt: gr.SelectData):
    print(sel_pix)
    if len(sel_pix) < 5:
        sel_pix.append((evt.index, 1))  # default to a foreground point
    img = cv2.imread(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # draw all selected points
    for point, label in sel_pix:
        cv2.drawMarker(img, point, colors[label], markerType=markers[label],
                       markerSize=20, thickness=5)
    print(sel_pix)
    return img, sel_pix


# Undo the most recently selected point and redraw the remaining ones.
def undo_points(orig_img, sel_pix):
    if isinstance(orig_img, int):
        # an int means the image was selected from the examples gallery
        temp = cv2.imread(image_examples[orig_img][0])
    else:
        temp = cv2.imread(orig_img)
    # convert to RGB once here; no further channel swap is needed
    temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    if len(sel_pix) != 0:
        sel_pix.pop()
        # redraw the remaining points
        for point, label in sel_pix:
            cv2.drawMarker(temp, point, colors[label], markerType=markers[label],
                           markerSize=20, thickness=5)
    return temp, sel_pix


HF_TOKEN = os.environ.get('HF_KEY')

client = Client("Canyu/Diception", max_workers=3, hf_token=HF_TOKEN)

# marker colors and cv2 marker types, indexed by point label (1 = foreground)
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# map the UI task names to the model's prompt tokens
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'human pose': '[[image2pose]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}


def download_additional_params(model_name, filename="add_params.bin"):
    # download the file and return its local path
    file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
    return file_path


# load the additional_params.bin file
def load_additional_params(model_name):
    # download additional_params.bin
    params_path = download_additional_params(model_name)
    # load its contents with torch.load()
    additional_params = torch.load(params_path, map_location='cpu')
    # return the loaded parameters
    return additional_params


def process_image_check(path_input, prompt, sel_points, semantic):
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if len(prompt) == 0:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )
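
# Illustrative sketch only (not executed by the app): how `postprocess` above is
# expected to slice the server output, assuming the remote endpoint returns a
# single image of horizontally concatenated tiles, one per requested task.
# 'result.png' is a hypothetical path used purely for illustration.
#
#   tiles = postprocess('result.png', ['depth', 'normal'])
#   # tiles == [(<left half as PIL.Image>, 'depth'),
#   #           (<right half as PIL.Image>, 'normal')]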

def process_image_4(image_path, prompt):
    """Build one request dict per requested task."""
    inputs = []
    for p in prompt:
        cur_p = map_prompt[p]
        coor_point = []
        point_labels = []
        cur_input = {
            # 'original_size': [[w, h]],
            # 'target_size': [[768, 768]],
            'prompt': [cur_p],
            'coor_point': coor_point,
            'point_labels': point_labels,
        }
        inputs.append(cur_input)
    return inputs


def inf(image_path, prompt, sel_points, semantic):
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    if 'point segmentation' in prompt and len(sel_points) == 0:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and len(sel_points) != 0:
        raise gr.Error(
            "Points are selected, so 'point segmentation' must be included in the prediction types."
        )
    if 'semantic segmentation' in prompt and semantic == '':
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and semantic != '':
        raise gr.Error(
            "A target category is given, so 'semantic segmentation' must be included in the prediction types."
        )

    # inputs = process_image_4(image_path, prompt, sel_points, semantic)
    # the remote endpoint expects the selected points serialized as a string
    points_str = str(sel_points)
    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=points_str,
        semantic_input=semantic,
        api_name="/inf"
    )
    return postprocess(result, prompt)


def clear_cache():
    return None, None


def run_demo_server():
    options = ['depth', 'normal', 'entity segmentation', 'human pose',
               'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Diception",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1, h2, h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
        """,
    ) as demo:
        selected_points = gr.State([])         # stores the selected points
        original_image = gr.State(value=None)  # stores the original image without point markers; defaults to None

        gr.HTML(
            """