# PTA-1 / app.py
# Hugging Face Space source; last change "Update app.py" by maxiw
# (commit a30e28f, verified).
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import ImageDraw
# Choose the execution device once; half precision is only used when CUDA exists.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Registries keyed by Hugging Face Hub repo id. The UI dropdown lists these keys.
# trust_remote_code=True allows the repo's own modeling/processing code to run.
models = {
    repo_id: AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in ("AskUI/PTA-1",)
}
processors = {
    repo_id: AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    for repo_id in ("AskUI/PTA-1",)
}
def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    """Draw rectangle outlines onto ``image`` in place and return it.

    Args:
        image: A PIL image. It is drawn on directly, so the caller's object changes.
        bounding_boxes: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes. ``None``
            entries are skipped, so a failed detection can be passed through
            without crashing.
        outline_color: Outline colour passed to ``ImageDraw.rectangle``.
        line_width: Outline width in pixels.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image)
    for box in bounding_boxes:
        if box is None:
            continue
        x0, y0, x1, y1 = box
        # Normalise corner order: recent Pillow versions raise ValueError when
        # the second corner precedes the first on either axis.
        draw.rectangle(
            [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)],
            outline=outline_color,
            width=line_width,
        )
    return image
def florence_output_to_box(output):
    """Extract one ``[xmin, ymin, xmax, ymax]`` box from parsed Florence-style output.

    The first polygon of the first entry under ``"polygons"`` is reduced to its
    axis-aligned bounding box (min/max over all vertices, so vertex order and
    count do not matter). If no usable polygon exists, the first entry under
    ``"bboxes"`` is used instead. Coordinates are truncated with ``int()``.

    Args:
        output: The task-specific mapping taken from
            ``processor.post_process_generation(...)[task]``. Presumably each
            polygon is a flat ``[x0, y0, x1, y1, ...]`` list; confirm against the
            processor if that changes.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as ints, or ``None`` when no usable shape is
        present or the data is malformed. The error is printed in that case.
    """
    try:
        polygons = output["polygons"] if "polygons" in output else None
        if polygons and polygons[0]:
            coords = [int(el) for el in polygons[0][0]]
            n_points = len(coords) // 2
            # Pair coordinates up; a dangling odd value is ignored.
            xs = coords[0:2 * n_points:2]
            ys = coords[1:2 * n_points:2]
            # At least two vertices are needed to span a box; otherwise fall back.
            if n_points >= 2:
                return [min(xs), min(ys), max(xs), max(ys)]

        bboxes = output["bboxes"] if "bboxes" in output else None
        if bboxes:
            target_bbox = [int(el) for el in bboxes[0]]
            # Only a full 4-coordinate box is drawable by draw_bounding_boxes.
            if len(target_bbox) == 4:
                return target_bbox
    except (TypeError, ValueError, IndexError, KeyError) as e:
        # Malformed model output: report it and treat it as "no box found".
        print(f"Error: {e}")
    return None
@spaces.GPU
def run_example(image, text_input, model_id="AskUI/PTA-1"):
    """Locate the UI element described by ``text_input`` in ``image``.

    Runs open-vocabulary detection with the selected model and draws the first
    detected box on an RGB copy of the image.

    Args:
        image: Input PIL image (from ``gr.Image(type="pil")``).
        text_input: Free-text description of the element to find.
        model_id: Key into the module-level ``models``/``processors`` registries.

    Returns:
        ``(box, annotated_image)``. ``box`` is ``[xmin, ymin, xmax, ymax]``, or
        ``None`` when the model output held no usable box. In that case the RGB
        copy is returned without annotation.

    Raises:
        gr.Error: if no image is given or the prompt is empty.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    if not text_input or not text_input.strip():
        raise gr.Error("Please enter a prompt describing the element to find.")

    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]

    task_prompt = "<OPEN_VOCABULARY_DETECTION>"
    prompt = task_prompt + text_input
    # convert() returns a new image, so the annotation below does not touch the input.
    image = image.convert("RGB")
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept, presumably because post_process_generation
    # parses location tokens out of the raw text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )
    target_box = florence_output_to_box(parsed_answer[task_prompt])
    if target_box is None:
        # Nothing detected: show the unannotated image instead of crashing in drawing.
        return None, image
    return target_box, draw_bounding_boxes(image, [target_box])
# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component visible in this file sets elem_id="output", so the
# #output rule appears to match nothing. Confirm whether it is leftover or was
# meant to target one of the output widgets.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    # Page header and a link to the model card.
    gr.Markdown("\n# PTA-1: Controlling Computers with Small Models\n")
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")

    with gr.Row():
        # Left column: screenshot, model choice and the element description.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            model_selector = gr.Dropdown(
                label="Model",
                choices=list(models),
                value="AskUI/PTA-1",
            )
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        # Right column: the predicted box and the screenshot with the box drawn.
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image")

    # Clicking an example only fills the inputs (cache_examples=False), and
    # run_example falls back to its default model_id.
    gr.Examples(
        label="Try examples",
        fn=run_example,
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        cache_examples=False,
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
    )

    submit_btn.click(
        fn=run_example,
        inputs=[input_img, text_input, model_selector],
        outputs=[model_output_text, annotated_image],
    )

demo.launch(debug=False)