# Spaces: Running on Zero
import torch | |
import spaces | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
from PIL import ImageDraw | |
# Choose the accelerator and a matching floating-point precision once at
# import time: half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Model registry keyed by Hugging Face repo id. `trust_remote_code=True`
# lets transformers execute the modelling code shipped in that repo.
models = {
    repo_id: AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in ("AskUI/PTA-1",)
}

# Matching processors, keyed by the same repo ids as `models`.
processors = {
    repo_id: AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in ("AskUI/PTA-1",)
}
def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    """Outline every box in ``bounding_boxes`` on ``image``.

    Args:
        image: Image handed to ``ImageDraw.Draw``; it is drawn on directly.
        bounding_boxes: Iterable of 4-item boxes, each unpacked as
            ``(xmin, ymin, xmax, ymax)``.
        outline_color: Outline colour passed to ``rectangle``.
        line_width: Outline width passed to ``rectangle``.

    Returns:
        The same ``image`` object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    for left, top, right, bottom in bounding_boxes:
        canvas.rectangle(
            [left, top, right, bottom],
            outline=outline_color,
            width=line_width,
        )
    return image
def florence_output_to_box(output):
    """Extract a single ``[xmin, ymin, xmax, ymax]`` box from parsed model output.

    The first polygon is preferred; the first entry of ``"bboxes"`` is used
    when no usable polygon is present.

    Args:
        output: Mapping that may contain ``"polygons"`` and/or ``"bboxes"``.
            ``output["polygons"][0][0]`` is read as a flat coordinate
            sequence ``x0, y0, x1, y1, ...``; ``output["bboxes"][0]`` is read
            as the box coordinates themselves.

    Returns:
        A list of four ints, or ``None`` when neither key yields a usable box.
    """
    # Everything the lookups / int() conversions below can raise on
    # malformed input (including ``output`` being None).
    malformed = (TypeError, ValueError, IndexError, KeyError, OverflowError)

    try:
        if "polygons" in output and len(output["polygons"]) > 0:
            coords = [int(el) for el in output["polygons"][0][0]]
            ys = coords[1::2]
            xs = coords[0::2][: len(ys)]  # ignore a trailing unpaired value
            if len(ys) >= 2:
                # Axis-aligned bounding box of all vertices: independent of
                # vertex count and of which corner the polygon starts at.
                return [min(xs), min(ys), max(xs), max(ys)]
    except malformed as e:
        # Report and fall through so a bad polygon does not hide a good bbox.
        print(f"Error: {e}")

    try:
        if "bboxes" in output and len(output["bboxes"]) > 0:
            return [int(el) for el in output["bboxes"][0]]
    except malformed as e:
        print(f"Error: {e}")

    return None
def run_example(image, text_input, model_id="AskUI/PTA-1"):
    """Locate the UI element described by ``text_input`` inside ``image``.

    Args:
        image: PIL image to search; it is converted to RGB before use.
        text_input: Text appended to the ``<OPEN_VOCABULARY_DETECTION>``
            task prompt.
        model_id: Key into the module-level ``models`` / ``processors``
            registries.

    Returns:
        A ``(target_box, annotated_image)`` tuple. ``target_box`` is
        ``[xmin, ymin, xmax, ymax]`` and the image has that box drawn on it.
        When no box can be extracted from the model output, ``target_box``
        is ``None`` and the RGB image is returned without annotation.
    """
    # NOTE(review): `spaces` is imported at file level but never used. If
    # this app runs on ZeroGPU hardware, GPU work normally has to happen in
    # a function decorated with `@spaces.GPU` - confirm whether one is needed.
    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]

    task_prompt = "<OPEN_VOCABULARY_DETECTION>"
    prompt = task_prompt + text_input
    image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # Deterministic decoding: beam search with 3 beams, no sampling.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept in the decoded text that is handed to the
    # processor's post-processing step.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    target_box = florence_output_to_box(parsed_answer[task_prompt])
    if target_box is None:
        # Nothing usable was detected: passing [None] to the drawing helper
        # would fail on unpacking, so return the image untouched.
        return None, image
    return target_box, draw_bounding_boxes(image, [target_box])
# Custom CSS handed to gr.Blocks; the rule targets an element with id "output".
# NOTE(review): no component below sets elem_id="output", so this rule may
# currently match nothing - confirm.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    # Page header and a link to the model card.
    gr.Markdown(
        """
# PTA-1: Controlling Computers with Small Models
""")
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                label="Model",
                value="AskUI/PTA-1",
            )
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        # Right column: outputs.
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image")

    # Example (image, prompt) pairs; results are not cached.
    gr.Examples(
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        fn=run_example,
        cache_examples=False,
        label="Try examples",
    )

    # Run detection with the selected model when the button is pressed.
    submit_btn.click(
        fn=run_example,
        inputs=[input_img, text_input, model_selector],
        outputs=[model_output_text, annotated_image],
    )

demo.launch(debug=False)