import os
from pathlib import Path

import gradio as gr
import torch
import torchvision.transforms as transforms
from gradio_client import Client, handle_file
from huggingface_hub import hf_hub_download
from PIL import Image


class Examples(gr.helpers.Examples):
    """gr.Examples subclass that allows pointing the example cache at a custom folder."""

    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


HF_TOKEN = os.environ.get('HF_KEY')
client = Client("Canyu/Diception", max_workers=3, hf_token=HF_TOKEN)

# Task prompts understood by the backend. Keys must match the CheckboxGroup
# option names used in run_demo_server(), since process_image_4 looks them up.
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'pose': '[[image2pose]]',
    'entity': '[[image2panoptic coarse]]',  # entity segmentation
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}


def download_additional_params(model_name, filename="add_params.bin"):
    # Download the file from the Hub and return its local path.
    file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
    return file_path


def load_additional_params(model_name):
    # Download add_params.bin, load it with torch.load(), and return its contents.
    params_path = download_additional_params(model_name)
    additional_params = torch.load(params_path, map_location='cpu')
    return additional_params


def process_image_check(path_input, prompt):
    if path_input is None:
        raise gr.Error("Missing image in the left pane: please upload an image first.")
    if len(prompt) == 0:
        raise gr.Error("At least 1 prediction type is needed.")


def process_image_4(image_path, prompt):
    """Build one model input dict per selected prediction type."""
    inputs = []
    # Fixed seed so repeated runs on the same image are reproducible
    # (the seed value itself is arbitrary).
    generator = torch.Generator().manual_seed(0)
    for p in prompt:
        cur_p = map_prompt[p]

        image = Image.open(image_path)
        w, h = image.size

        # These tasks take no point prompts, so send zeroed placeholders.
        coor_point = torch.zeros((1, 5, 2), dtype=torch.float32)
        point_labels = torch.zeros((1, 5, 1), dtype=torch.float32)

        # Resize to the model's 768x768 input resolution and normalize to [-1, 1].
        image = image.resize((768, 768), Image.LANCZOS).convert('RGB')
        to_tensor = transforms.ToTensor()
        image = (to_tensor(image) - 0.5) * 2

        cur_input = {
            'input_images': image.unsqueeze(0),
            'original_size': torch.tensor([[w, h]]),
            'target_size': torch.tensor([[768, 768]]),
            'prompt': [cur_p],
            'coor_point': coor_point,
            'point_labels': point_labels,
            'generator': generator,
        }
        inputs.append(cur_input)

    return inputs


def inf(image_path, prompt):
    print(image_path)
    print(prompt)
    inputs = process_image_4(image_path, prompt)
    return client.predict(data=inputs, api_name="/inf")


def clear_cache():
    return None, None


def run_demo_server():
    # Prediction types exposed in the UI; each must be a key of map_prompt.
    options = ['depth', 'normal', 'entity', 'pose']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(theme=gradio_theme, title="Diception") as demo:
        with gr.Row():
            gr.Markdown("# Diception Demo")
        with gr.Row():
            gr.Markdown(
                "### All results are generated using the same single model. "
                "To facilitate input processing, we separate point-prompted segmentation "
                "and semantic segmentation, as they require input points and segmentation targets."
            )
        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
        with gr.Row():
            with gr.Column():
                matting_image_input = gr.Image(label="Input Image", type="filepath")
                with gr.Row():
                    matting_image_submit_btn = gr.Button(value="Estimate Matting", variant="primary")
                    matting_image_reset_btn = gr.Button(value="Reset")
                with gr.Row():
                    img_clear_button = gr.Button("Clear Cache")
            with gr.Column():
                matting_image_output = gr.Image(label='Matting Output')

        img_clear_button.click(clear_cache, outputs=[matting_image_input, matting_image_output])

        # Validate inputs first; only run inference if validation succeeds.
        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[matting_image_input, checkbox_group],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[matting_image_input, checkbox_group],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        matting_image_reset_btn.click(
            fn=lambda: (None, None),
            inputs=[],
            outputs=[matting_image_input, matting_image_output],
            queue=False,
        )

        gr.Examples(
            fn=inf,
            examples=[
                ["assets/person.jpg", ['depth', 'normal', 'entity', 'pose']],
            ],
            inputs=[matting_image_input, checkbox_group],
            outputs=[matting_image_output],
            cache_examples=True,
        )

    demo.queue(api_open=False).launch()


if __name__ == '__main__':
    run_demo_server()
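
# ----------------------------------------------------------------------------
# Optional: a minimal sketch of querying the same Space endpoint without the
# UI. This assumes the "/inf" endpoint accepts the payload built by
# process_image_4 and that 'assets/person.jpg' exists locally; both are
# assumptions, not verified behavior of the backend.
#
#   inputs = process_image_4('assets/person.jpg', ['depth', 'normal'])
#   result = client.predict(data=inputs, api_name="/inf")
#   print(result)
# ----------------------------------------------------------------------------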