import os
import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np


class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` subclass that can use a custom cache folder.

    Not instantiated in this module (the ``gr.Examples`` usage inside
    ``run_demo_server`` is commented out).
    """

    def __init__(self, *args, cached_folder=None, **kwargs):
        # _initiated_directly=False is forwarded to the gradio base class;
        # create() is then called explicitly after the cache folder is set.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


def postprocess(output, prompt):
    """Cut the horizontally concatenated result image into one tile per task.

    Args:
        output: Image path returned by the remote ``/inf`` endpoint.
        prompt: Selected task names; one tile is cut per entry, left to right.

    Returns:
        List of ``(PIL.Image, caption)`` tuples for ``gr.Gallery``; the caption
        is the task name. Empty when ``prompt`` is empty.
    """
    n = len(prompt)
    if n == 0:
        # Nothing to slice (and avoids a ZeroDivisionError below).
        return []
    result = []
    # Context manager closes the file handle; crop() materialises the pixel
    # data, so the returned tiles stay valid after the file is closed.
    with Image.open(output) as image:
        w, h = image.size
        slice_width = w // n
        for i, caption in enumerate(prompt):
            left = i * slice_width
            # The last tile absorbs any remainder from the integer division.
            right = (i + 1) * slice_width if i < n - 1 else w
            result.append((image.crop((left, 0, right, h)), caption))
    return result


# Maximum number of prompt points a user may place on the image.
_MAX_POINTS = 5


def _read_rgb(path):
    """Read an image file with OpenCV and return it as an RGB ndarray.

    Raises:
        gr.Error: if OpenCV cannot decode ``path`` (``cv2.imread`` returned None).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise gr.Error(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _draw_points(img, sel_pix):
    """Draw one marker per ``(point, label)`` in ``sel_pix`` onto ``img`` in place; return ``img``."""
    for point, label in sel_pix:
        cv2.drawMarker(img, point, colors[label], markerType=markers[label],
                       markerSize=20, thickness=5)
    return img


# user click the image to get points, and show the points on the image
def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and return the image annotated with all points.

    Args:
        img: File path of the image currently shown in the input component.
        sel_pix: Session list of ``(point, label)`` entries; mutated in place.
        evt: Gradio select event; ``evt.index`` is used as the marker position.

    Returns:
        ``(annotated RGB ndarray, sel_pix)``.
    """
    print(sel_pix)
    # Once the limit is reached, further clicks only redraw existing points.
    if len(sel_pix) < _MAX_POINTS:
        sel_pix.append((evt.index, 1))  # default: foreground point (label 1)
    img = _draw_points(_read_rgb(img), sel_pix)
    print(sel_pix)
    return img, sel_pix


# undo the selected point
def undo_points(orig_img, sel_pix):
    """Drop the most recent point and redraw the remaining ones on a clean image.

    Args:
        orig_img: Path of the originally uploaded image (without markers), an
            int index into ``image_examples``, or None if nothing was uploaded.
        sel_pix: Session list of ``(point, label)`` entries; mutated in place.

    Returns:
        ``(RGB ndarray, sel_pix)``; the image is left unchanged when there is
        no original image to redraw on.
    """
    if orig_img is None:
        # No upload yet (or the demo was reset): nothing to redraw.
        return gr.update(), sel_pix
    if isinstance(orig_img, int):
        # The image was selected from examples.
        # NOTE(review): `image_examples` is not defined in this module as
        # shown, so this branch would raise NameError - TODO confirm.
        temp = _read_rgb(image_examples[orig_img][0])
    else:
        temp = _read_rgb(orig_img)
    if sel_pix:
        sel_pix.pop()
    # `temp` is already RGB. The former "swap channels if R == B at (0, 0)"
    # heuristic was removed: it turned correct RGB images into BGR whenever the
    # top-left pixel was gray/white/black.
    return _draw_points(temp, sel_pix), sel_pix


HF_TOKEN = os.environ.get('HF_KEY')
client = Client("Canyu/Diception", max_workers=3, hf_token=HF_TOKEN)

# Marker colour (RGB) and OpenCV marker type, both indexed by point label.
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

# UI task name -> task prompt token sent to the model.
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'human pose': '[[image2pose]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}


def download_additional_params(model_name, filename="add_params.bin"):
    """Download ``filename`` from the Hub repo ``model_name`` and return its local path."""
    # hf_hub_download was used here without ever being imported (NameError).
    # Import lazily so the module still loads when this function is unused.
    from huggingface_hub import hf_hub_download

    # `token` replaces the deprecated `use_auth_token` argument.
    return hf_hub_download(repo_id=model_name, filename=filename, token=HF_TOKEN)


def load_additional_params(model_name):
    """Download the additional-parameter file for ``model_name`` and load it onto the CPU."""
    params_path = download_additional_params(model_name)
    # NOTE(review): torch.load unpickles arbitrary objects; if this file holds
    # only tensors, consider weights_only=True.
    return torch.load(params_path, map_location='cpu')


def process_image_check(path_input, prompt, sel_points, semantic):
    """Cheap pre-flight validation run before inference.

    Raises:
        gr.Error: if no image was provided or no task was selected.
    """
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if not prompt:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )


def process_image_4(image_path, prompt):
    """Build one request dict per selected task (not called in this module)."""
    return [
        {
            # 'original_size': [[w,h]],
            # 'target_size': [[768, 768]],
            'prompt': [map_prompt[p]],
            'coor_point': [],
            'point_labels': [],
        }
        for p in prompt
    ]


def inf(image_path, prompt, sel_points, semantic):
    """Validate the task/argument combination and run remote inference.

    Args:
        image_path: Path of the input image.
        prompt: Selected task names (keys of ``map_prompt``).
        sel_points: List of ``(point, label)`` entries from the UI.
        semantic: Category name for semantic segmentation ('' if unused).

    Returns:
        Gallery items as produced by ``postprocess``.

    Raises:
        gr.Error: if points / category are given without their task, or
            the task is selected without them.
    """
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    has_points = bool(sel_points)
    # Treat None / whitespace-only input as "no category given".
    has_semantic = bool(semantic and semantic.strip())

    if 'point segmentation' in prompt and not has_points:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and has_points:
        raise gr.Error(
            "You must select 'point segmentation' when performing point segmentation."
        )
    if 'semantic segmentation' in prompt and not has_semantic:
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and has_semantic:
        raise gr.Error(
            "You must select 'semantic segmentation' when performing semantic segmentation."
        )

    # The remote endpoint receives the points as their Python repr string.
    prompt_str = str(sel_points)

    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=prompt_str,
        semantic_input=semantic,
        api_name="/inf"
    )
    return postprocess(result, prompt)


def clear_cache():
    """Return ``(None, None)`` to clear two components (wiring is commented out)."""
    return None, None


def run_demo_server():
    """Build the Gradio Blocks UI, wire up event handlers and launch the server."""
    options = ['depth', 'normal', 'entity segmentation', 'human pose',
               'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Diception",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head=""" """,
    ) as demo:
        selected_points = gr.State([])  # store points
        # store original image without points, default None
        original_image = gr.State(value=None)
        gr.HTML(
            """

DICEPTION: A Generalist Diffusion Model for Vision Perception

One single model solves multiple perception tasks, producing impressive results!

badge-github-stars

"""
        )

        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
        with gr.Row():
            semantic_input = gr.Textbox(
                label="Category Name (for semantic segmentation only, in COCO)",
                placeholder="e.g. person/cat/dog/elephant......",
            )
        with gr.Row():
            gr.Markdown('For non-human image inputs, the pose results may have issues. Same when perform semantic segmentation with categories that are not in COCO.')

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
            with gr.Column():
                with gr.Row():
                    gr.Markdown('You can click on the image to select points prompt. At most 5 point.')

                matting_image_submit_btn = gr.Button(
                    value="Run", variant="primary"
                )

                with gr.Row():
                    undo_button = gr.Button('Undo point')
                    matting_image_reset_btn = gr.Button(value="Reset")

                # with gr.Row():
                #     img_clear_button = gr.Button("Clear Cache")

            with gr.Column():
                matting_image_output = gr.Gallery(label="Results")

        # img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])

        # NOTE(review): inference reads `input_image`, which after a click holds
        # the marker-annotated image returned by get_point, not the clean upload
        # kept in `original_image` - confirm this is intended.
        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        # Also clear `original_image` so a later Undo cannot redraw a discarded image.
        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                [],
                None,
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
                selected_points,
                original_image,
            ],
            queue=False,
        )

        # once user upload an image, the original image is stored in `original_image`
        def store_img(img):
            # when new image is uploaded, `selected_points` should be empty
            return img, []

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )

        input_image.select(
            get_point,
            [input_image, selected_points],
            [input_image, selected_points],
        )

        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )

        # gr.Examples(
        #     fn=inf,
        #     examples=[
        #         ["assets/person.jpg", ['depth', 'normal', 'entity segmentation', 'pose']]
        #     ],
        #     inputs=[input_image, checkbox_group],
        #     outputs=[matting_image_output],
        #     cache_examples=True,
        # )

    demo.queue(
        api_open=False,
    ).launch()


if __name__ == '__main__':
    run_demo_server()