import os

import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
from huggingface_hub import hf_hub_download
import torch
import torchvision.transforms as transforms
from PIL import Image
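

# Thin subclass of gr.helpers.Examples that accepts an explicit cache folder and
# builds the example cache via self.create(). The demo below currently uses the
# stock gr.Examples, so this class only matters if a custom cached_folder is wanted.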
class Examples(gr.helpers.Examples):
    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


HF_TOKEN = os.environ.get('HF_KEY')

client = Client("Canyu/Diception",
                max_workers=3,
                hf_token=HF_TOKEN)

# Maps each task name selected in the UI to the text prompt expected by the Diception Space.
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'pose': '[[image2pose]]',
    # 'entity' is the checkbox label used in the demo below; it is assumed here to
    # mean entity segmentation and maps to the same prompt.
    'entity': '[[image2panoptic coarse]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}


def download_additional_params(model_name, filename="add_params.bin"):
    # Download the file from the Hub and return its local path
    file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
    return file_path


# Load the additional_params.bin file
def load_additional_params(model_name):
    # Download additional_params.bin
    params_path = download_additional_params(model_name)
    # Load the file contents with torch.load()
    additional_params = torch.load(params_path, map_location='cpu')
    # Return the loaded parameters
    return additional_params
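
# Neither helper above is called in this demo. A hypothetical use would be fetching
# extra weights shipped with a model repo, e.g.:
#   extra = load_additional_params("<model-repo-id>")  # placeholder repo id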


def process_image_check(path_input, prompt):
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if len(prompt) == 0:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )


def process_image_4(image_path, prompt):
    inputs = []
    for p in prompt:
        cur_p = map_prompt[p]

        coor_point = []
        point_labels = []

        cur_input = {
            # 'original_size': [[w, h]],
            # 'target_size': [[768, 768]],
            'prompt': [cur_p],
            'coor_point': coor_point,
            'point_labels': point_labels,
        }
        inputs.append(cur_input)
    return inputs
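
# For reference, the payload built above contains one entry per selected task, e.g.
# process_image_4(img, ['depth', 'normal']) yields (sketch derived from the code
# above; no claim about other fields the backend may accept):
#   [
#       {'prompt': ['[[image2depth]]'],  'coor_point': [], 'point_labels': []},
#       {'prompt': ['[[image2normal]]'], 'coor_point': [], 'point_labels': []},
#   ]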


def inf(image_path, prompt):
    print(image_path)
    print(prompt)
    inputs = process_image_4(image_path, prompt)
    return client.predict(
        image=handle_file(image_path),
        data=inputs,
        api_name="/inf"
    )
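
# Hypothetical direct call, bypassing the UI (assumes HF_KEY grants access to the
# Canyu/Diception Space and that its /inf endpoint returns something gr.Image can
# display, such as an image file path):
#   result = inf("assets/person.jpg", ['depth', 'normal'])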


def clear_cache():
    return None, None


def run_demo_server():
    options = ['depth', 'normal', 'entity', 'pose']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Matting",
    ) as demo:
        with gr.Row():
            gr.Markdown("# Diception Demo")
        with gr.Row():
gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.") | |
        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")

        with gr.Row():
            with gr.Column():
                matting_image_input = gr.Image(
                    label="Input Image",
                    type="filepath",
                )

                with gr.Row():
                    matting_image_submit_btn = gr.Button(
                        value="Estimate Matting", variant="primary"
                    )
                    matting_image_reset_btn = gr.Button(value="Reset")

                with gr.Row():
                    img_clear_button = gr.Button("Clear Cache")

            with gr.Column():
                matting_image_output = gr.Image(label='Matting Output')

        img_clear_button.click(clear_cache, outputs=[matting_image_input, matting_image_output])

        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[matting_image_input, checkbox_group],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[
                matting_image_input,
                checkbox_group,
            ],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
            ),
            inputs=[],
            outputs=[
                matting_image_input,
                matting_image_output,
            ],
            queue=False,
        )

        gr.Examples(
            fn=inf,
            examples=[
                ["assets/person.jpg", ['depth', 'normal', 'entity', 'pose']],
            ],
            inputs=[matting_image_input, checkbox_group],
            outputs=[matting_image_output],
            cache_examples=True,
            # cached_folder="cache_dir",  # only supported by the Examples subclass above
        )
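
        # Note: with cache_examples=True, Gradio caches example outputs by calling
        # fn (here inf, a remote call to the Canyu/Diception Space) on each listed
        # example; depending on the Gradio version this happens at startup, so the
        # Space must be reachable when the demo launches.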

    demo.queue(
        api_open=False,
    ).launch()


if __name__ == '__main__':
    run_demo_server()
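

# Minimal local run (assumes this file is saved as app.py, assets/person.jpg exists,
# and HF_KEY holds a token with access to the Canyu/Diception Space):
#   HF_KEY=<your_hf_token> python app.py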