import os
import random
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

from runner import Runner


def show_mask(mask, ax, color='blue'):
    """Overlay a single mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: array whose last two dims are (H, W); it is reshaped to
            (H, W, 1), so any leading dims must have size 1.
        ax: matplotlib Axes to draw on.
        color: 'blue' for the reference/prompt mask; any other value draws
            the target mask in green.
    """
    if color == 'blue':  # reference mask
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:  # target mask
        rgba = np.array([78 / 255, 238 / 255, 148 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: array aligned with ``coords``; 1 = positive, 0 = negative.
        ax: matplotlib Axes to draw on.
        marker_size: scatter marker size.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` = (x0, y0, x1, y1) on ``ax``."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
                               facecolor=(0, 0, 0, 0), lw=2))


def show_img_point_box_mask(img, input_point=None, input_label=None, box=None,
                            masks=None, save_path=None, mode='mask', color='blue'):
    """Render ``img`` with one kind of overlay and save the figure to ``save_path``.

    Args:
        img: image passed to ``plt.imshow``.
        input_point, input_label: used when ``mode == 'point'``.
        box: used when ``mode == 'box'``.
        masks: used for any other ``mode`` (drawn with ``show_mask``).
        save_path: destination passed to ``plt.savefig``.
        mode: 'point', 'box', or anything else for a mask overlay.
        color: mask colour selector forwarded to ``show_mask``.
    """
    plt.figure(figsize=(10, 10))
    try:
        plt.imshow(img)
        if mode == 'point':
            show_points(input_point, input_label, plt.gca())
            plt.axis('on')
        elif mode == 'box':
            show_box(box, plt.gca())
            plt.axis('on')
        else:
            show_mask(masks, plt.gca(), color=color)
            plt.axis('off')
        plt.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing/saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close()


def create_oss_demo(
        runner: Runner,
        pipe: None = None
) -> gr.Blocks:
    """Build the one-shot semantic/part segmentation (OSS+OPS) tab.

    The first Submit button runs ``runner.inference_oss_ops`` on the prompt,
    both targets and the selected version, writing a status into ``info``.
    The second Submit button runs ``runner.controllable_mask_output`` with the
    slider value K and fills the three output images. Clear resets every
    input and output via ``runner.clear_fn``.
    """
    examples = [
        ['./gradio_demo/images/horse1.png', './gradio_demo/images/horse2.png', './gradio_demo/images/horse3.png'],
        ['./gradio_demo/images/hmbb1.png', './gradio_demo/images/hmbb2.png', './gradio_demo/images/hmbb3.png'],
        ['./gradio_demo/images/earth1.png', './gradio_demo/images/earth2.png', './gradio_demo/images/earth3.png'],
        ['./gradio_demo/images/elephant1.png', './gradio_demo/images/elephant2.png', './gradio_demo/images/elephant3.png'],
        ['./gradio_demo/images/dinosaur1.png', './gradio_demo/images/dinosaur2.png', './gradio_demo/images/dinosaur3.png'],
    ]
    version_choices = [
        'version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
        'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)',
        'version 3 (🔻 multiple instances 🔻 whole, 🔺 part)',
    ]

    with gr.Blocks() as oss_demo:
        with gr.Column():
            # inputs
            with gr.Row():
                img_input_prompt = gr.ImageMask(label='Prompt (提示图)')
                img_input_target1 = gr.Image(label='Target 1 (测试图1)')
                img_input_target2 = gr.Image(label='Target 2 (测试图2)')

            # The previous default ('version 1 (🔺 whole, 🔻 part)') was not one
            # of the choices; default to the first real choice instead.
            version = gr.Radio(
                version_choices,
                type="value",
                value=version_choices[0],
                label='Multiple Instances (version 1), Single Instance (version 2), Part of a object (version 3)')

            with gr.Row():
                submit1 = gr.Button("提交 (Submit)")
                clear = gr.Button("清除 (Clear)")
            info = gr.Text(label="Processing result: ", interactive=False)

            # decision
            K = gr.Slider(0, 10, 10, step=1, label="Controllable mask output", interactive=True)
            submit2 = gr.Button("提交 (Submit)")

            # outputs
            with gr.Row():
                img_output_pmt = gr.Image(label='Prompt (提示图)')
                img_output_tar1 = gr.Image(label='Output 1 (输出图1)')
                img_output_tar2 = gr.Image(label='Output 2 (输出图2)')

            # images
            # NOTE(review): fn gets only the 3 image inputs here (no `version`),
            # while submit1 passes 4; this only matters if examples are cached
            # or run on click — confirm inference_oss_ops' signature.
            gr.Examples(
                examples=examples,
                fn=runner.inference_oss_ops,
                inputs=[img_input_prompt, img_input_target1, img_input_target2],
                outputs=info
            )

        submit1.click(
            fn=runner.inference_oss_ops,
            inputs=[img_input_prompt, img_input_target1, img_input_target2, version],
            outputs=info
        )
        submit2.click(
            fn=runner.controllable_mask_output,
            inputs=K,
            outputs=[img_output_pmt, img_output_tar1, img_output_tar2]
        )
        clear.click(
            fn=runner.clear_fn,
            inputs=None,
            outputs=[img_input_prompt, img_input_target1, img_input_target2, info,
                     img_output_pmt, img_output_tar1, img_output_tar2],
            queue=False
        )

    return oss_demo


def create_vos_demo(
        runner: Runner,
        pipe: None = None
) -> gr.Interface:
    """Placeholder for a video object segmentation tab (not implemented)."""
    raise NotImplementedError


def create_demo(
        runner: Runner,
        pipe: None = None
) -> gr.TabbedInterface:
    """Assemble the full demo: a titled TabbedInterface with the OSS+OPS tab."""
    # NOTE(review): this copy of the file lost the title's original HTML markup,
    # including the [paper]/[code] link targets; restore them from upstream.
    title = (
        "Matcher🎯: Segment Anything with One Shot Using All-Purpose Feature Matching"
        "<br>"
        "[paper] [code]"
        "<br>"
        "Matcher can segment anything with one shot by integrating an all-purpose "
        "feature extraction model and a class-agnostic segmentation model."
    )

    oss_demo = create_oss_demo(runner=runner, pipe=pipe)
    # vos_demo = create_vos_demo(runner=runner, pipe=pipe)  # tab not implemented yet

    demo = gr.TabbedInterface(
        [oss_demo, ],
        ['OSS+OPS', ],
        title=title)
    return demo


if __name__ == '__main__':
    pipe = None
    HF_TOKEN = os.getenv('HF_TOKEN')
    runner = Runner(HF_TOKEN)
    demo = create_demo(runner, pipe)
    demo.launch(enable_queue=False)