File size: 4,848 Bytes
452e24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import torch
from PIL import Image
import ast

# Model locations. The paths are relative (resolved against the current working
# directory) and built with os.path.join for cross-platform compatibility.
MODEL_DIR = 'weights'
YOLO_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_detect', 'model.pt')
CAPTION_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_caption')
# BLIP2_CAPTION_MODEL_PATH = os.path.join(MODEL_DIR, 'icon_caption_blip2')  # if using the BLIP2 caption model

# Models are loaded once, at import time, and reused by every request.
# Use the path constant defined above rather than repeating the literal path,
# so changing MODEL_DIR actually affects which detector weights are loaded.
yolo_model = get_yolo_model(model_path=YOLO_MODEL_PATH)
# NOTE(review): model_name="ollama" is paired with a local weights directory;
# confirm how util.utils interprets model_name_or_path for this backend.
caption_model_processor = get_caption_model_processor(model_name="ollama", model_name_or_path=CAPTION_MODEL_PATH)
# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")

# Header rendered at the top of the page by gr.Markdown.
# The <div> badge block must not contain blank lines: in Markdown a blank line
# terminates an HTML block, and the 4-space-indented tags after it would then be
# rendered as a literal code block instead of as the arXiv badge.
MARKDOWN = """

# OmniParser for Pure Vision Based General GUI Agent嘻嘻 🔥

<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen parsing tool to convert general GUI screen to structured elements. 

"""

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU so that
# anything placed on DEVICE also works on CPU-only machines.
# NOTE(review): DEVICE is not referenced elsewhere in this file.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _parse_ocr_results(texts, boxes):
    """Pair OCR texts with usable bounding boxes, keeping the two lists index-aligned.

    Each box may arrive either as a coordinate sequence or as its string form
    (e.g. ``"(10, 20, 30, 40)"``). Strings are parsed with ``ast.literal_eval``,
    which only accepts Python literals and therefore cannot execute code the
    way ``eval`` can. Non-string boxes are used as-is.

    A box that cannot be parsed is skipped together with its text, so the i-th
    returned text always belongs to the i-th returned box.

    Args:
        texts: OCR strings from ``check_ocr_box``; assumed to be parallel to
            ``boxes`` (TODO confirm). ``zip`` truncates to the shorter list.
        boxes: bounding boxes or their string forms; may be ``None`` or empty.

    Returns:
        ``(ocr_bbox, ocr_text)``: two lists of equal length.
    """
    ocr_bbox, ocr_text = [], []
    if not boxes:
        return ocr_bbox, ocr_text
    for txt, raw_box in zip(texts, boxes):
        if isinstance(raw_box, str):
            try:
                box = ast.literal_eval(raw_box)
            except (SyntaxError, ValueError, TypeError):
                # Warn, drop this box and its text, and keep processing the rest.
                print(f"警告:无法解析边界框字符串:{raw_box}")
                continue
        else:
            box = raw_box
        ocr_bbox.append(box)
        ocr_text.append(txt)
    return ocr_bbox, ocr_text


def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> "tuple[Image.Image, str]":
    """Parse one screenshot and return the annotated image plus a text listing.

    Args:
        image_input: PIL image from the ``gr.Image(type='pil')`` input;
            ``None`` when the user submits without uploading anything.
        box_threshold: forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: forwarded to ``get_som_labeled_img`` as ``iou_threshold``.
        use_paddleocr: forwarded to ``check_ocr_box`` as ``use_paddleocr``.
        imgsz: forwarded to ``get_som_labeled_img`` as ``imgsz``.

    Returns:
        ``(labeled_image, description)``: the image decoded from the base64
        output of ``get_som_labeled_img``, and one ``icon i: ...`` line per
        parsed element.

    Raises:
        gr.Error: if no image was provided.
    """
    if image_input is None:
        raise gr.Error('Please upload an image first.')

    # The OCR and labeling helpers take a file path, so persist the upload.
    # NOTE(review): every request writes to the same fixed path; concurrent
    # requests would overwrite each other's file.
    image_save_path = 'imgs/saved_image_demo.png'
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # Scale overlay text/box drawing with the image width, relative to 3200 px.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        'text_scale': 0.8 * box_overlay_ratio,
        'text_thickness': max(int(2 * box_overlay_ratio), 1),
        'text_padding': max(int(3 * box_overlay_ratio), 1),
        'thickness': max(int(3 * box_overlay_ratio), 1),
    }

    ocr_bbox_rslt, _ = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=use_paddleocr,
    )
    text, ocr_bbox_input = ocr_bbox_rslt
    ocr_bbox, ocr_text = _parse_ocr_results(text, ocr_bbox_input)

    labeled_img_b64, _, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=ocr_text,
        iou_threshold=iou_threshold,
        imgsz=imgsz,
    )
    image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
    print('finish processing')
    description = '\n'.join(f'icon {i}: {v}' for i, v in enumerate(parsed_content_list))
    return image, description

# Page layout: header on top, then a row with input controls on the left and
# results on the right. Components are attached to whichever Row/Column context
# is open when they are created, so creation order determines placement.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Left column: the screenshot plus the settings passed to `process`.
        with gr.Column():
            image_input_component = gr.Image(label='Upload image', type='pil')
            box_threshold_component = gr.Slider(
                label='Box Threshold',
                value=0.05,
                minimum=0.01,
                maximum=1.0,
                step=0.01,
            )
            iou_threshold_component = gr.Slider(
                label='IOU Threshold',
                value=0.1,
                minimum=0.01,
                maximum=1.0,
                step=0.01,
            )
            use_paddleocr_component = gr.Checkbox(value=True, label='Use PaddleOCR')
            imgsz_component = gr.Slider(
                label='Icon Detect Image Size',
                value=640,
                minimum=640,
                maximum=1920,
                step=32,
            )
            submit_button_component = gr.Button(variant='primary', value='Submit')
        # Right column: annotated screenshot and the textual element list.
        with gr.Column():
            image_output_component = gr.Image(label='Image Output', type='pil')
            text_output_component = gr.Textbox(placeholder='Text Output', label='Parsed screen elements')

    # `inputs` order must match the positional parameters of `process`;
    # `outputs` receives its (image, text) return values in order.
    submit_button_component.click(
        process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component,
        ],
        outputs=[image_output_component, text_output_component],
    )

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch a server as a side effect.
# NOTE(review): share=True requests a public Gradio share link, and
# server_name='0.0.0.0' listens on all network interfaces; confirm this
# exposure is intended.
if __name__ == '__main__':
    demo.launch(share=True, server_port=7861, server_name='0.0.0.0')