"""Gradio app for object detection with DETR/YOLOS (transformers) and YOLOv8 (ultralyticsplus)."""

import functools
import io
import itertools
import os
import pathlib
import threading

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import validators
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
from ultralyticsplus import YOLO, render_result

# colors for visualization
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Class names for the YOLOv8 model, indexed by class id.
# NOTE(review): presumably the class order of mshamrai/yolov8x-visdrone — confirm.
YOLOV8_LABELS = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
                 'tricycle', 'awning-tricycle', 'bus', 'motor']

# Prediction/NMS settings for the YOLOv8 model. They are written to
# ``model.overrides`` and also passed explicitly to ``predict()``, because
# keyword arguments take precedence over both overrides and predict()'s own
# method defaults.
_YOLO_PREDICT_ARGS = {
    'conf': 0.15,  # NMS confidence threshold
    'iou': 0.05,  # NMS IoU threshold
    'agnostic_nms': False,  # NMS class-agnostic
    'max_det': 1000,  # maximum number of detections per image
}

# Seconds to wait when downloading an image from a user-supplied URL.
_URL_TIMEOUT = 10

# The cached YOLO model may be used by the URL and upload handlers at the same
# time; Ultralytics advises against sharing one model across threads, so
# serialize access to it.
_YOLO_LOCK = threading.Lock()


def make_prediction(img, feature_extractor, model):
    """Run a transformers detection model on ``img`` and post-process the output.

    Args:
        img: PIL image to run detection on.
        feature_extractor: the image processor matching ``model``.
        model: a DETR or YOLOS object-detection model.

    Returns:
        A list with one dict per image holding ``scores``, ``labels`` and
        ``boxes`` tensors (boxes are used downstream as
        ``(xmin, ymin, xmax, ymax)`` pixel coordinates). Predictions are not
        filtered by the UI threshold here.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    with torch.no_grad():  # inference only; skip building the autograd graph
        outputs = model(**inputs)
    # PIL reports (width, height); the post-processors expect (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    if hasattr(feature_extractor, "post_process_object_detection"):
        # threshold=0.0 keeps every query, matching the deprecated post_process().
        return feature_extractor.post_process_object_detection(
            outputs, threshold=0.0, target_sizes=img_size
        )
    return feature_extractor.post_process(outputs, img_size)


def fig2img(fig):
    """Render a matplotlib figure (tight bounding box) into a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, bbox_inches="tight")
    buf.seek(0)
    # Image.open is lazy and keeps a reference to ``buf``, so the buffer stays alive.
    return Image.open(buf)


def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
    """Draw the detections scoring at least ``threshold`` on top of ``pil_img``.

    Args:
        pil_img: the image the detections refer to.
        output_dict: one entry of ``make_prediction``'s result (``scores``,
            ``labels`` and ``boxes`` tensors).
        threshold: minimum score for a box to be drawn. Inclusive, to agree with
            the "ABOVE THRESHOLD OR EQUAL" section of the text report.
        id2label: optional mapping from label id to class name.

    Returns:
        The annotated image as a PIL image.
    """
    keep = output_dict["scores"] >= threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    # Use a standalone Figure rather than pyplot: pyplot keeps every figure alive
    # until it is explicitly closed, and its global state is not thread-safe
    # across concurrent Gradio handlers.
    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(pil_img)
    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, itertools.cycle(COLORS)
    ):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15,
                bbox=dict(facecolor="yellow", alpha=0.5))
    ax.axis("off")
    return fig2img(fig)


def _load_image_from_url(url):
    """Download ``url`` and decode it as an RGB PIL image.

    Raises:
        gr.Error: if the download fails or the payload is not a readable image.
    """
    try:
        response = requests.get(url, timeout=_URL_TIMEOUT)
        response.raise_for_status()
        # Decode from the full body: ``response.raw`` would bypass requests'
        # Content-Encoding (e.g. gzip) handling.
        image = Image.open(io.BytesIO(response.content))
        # Normalize modes such as RGBA, palette or grayscale to 3-channel RGB.
        return image.convert("RGB")
    except (requests.RequestException, OSError) as err:
        raise gr.Error(f"Could not load an image from the URL: {err}") from err


def _resolve_image(url_input, image_input):
    """Pick the image to analyse: a valid URL takes priority over an upload.

    Raises:
        gr.Error: if neither a valid URL nor an uploaded image is provided.
    """
    if url_input and validators.url(url_input):
        return _load_image_from_url(url_input)
    if image_input is not None:
        return image_input
    raise gr.Error("Please enter a valid image URL or upload an image.")


@functools.lru_cache(maxsize=2)
def _load_transformers_model(model_name):
    """Load (and cache) the feature extractor and DETR/YOLOS model for ``model_name``."""
    if 'detr' in model_name:
        model_cls = DetrForObjectDetection
    elif 'yolos' in model_name:
        model_cls = YolosForObjectDetection
    else:
        raise gr.Error(f"Unsupported model: {model_name}")
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = model_cls.from_pretrained(model_name)
    return feature_extractor, model


@functools.lru_cache(maxsize=2)
def _load_yolo_model(model_name):
    """Load (and cache) a YOLOv8 model configured with ``_YOLO_PREDICT_ARGS``."""
    # Docs: https://docs.ultralytics.com/modes/predict/#key-features-of-predict-mode
    model = YOLO(model_name)
    model.overrides.update(_YOLO_PREDICT_ARGS)
    return model


def _detection_line(label, confidence, location):
    """Format one detection as a line of the text report."""
    return f"Detected `{label}` with confidence `{confidence}` at location `{location}`\n"


def _format_summary(above, below):
    """Join report lines into the two threshold sections of the text output."""
    # https://docs.python.org/3/library/string.html#format-examples
    return (
        "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL")
        + "".join(above)
        + "\n{:*^50}\n".format("BELOW THRESHOLD")
        + "".join(below)
    )


def _detect_with_yolov8(model_name, image, threshold):
    """Run a YOLOv8 model on ``image``; return (rendered image, text report).

    The rendered image shows every detection kept by NMS (``conf`` in
    ``_YOLO_PREDICT_ARGS``); ``threshold`` only splits the text report.
    """
    with _YOLO_LOCK:
        model = _load_yolo_model(model_name)
        results = model.predict(image, **_YOLO_PREDICT_ARGS)
        render = render_result(model=model, image=image, result=results[0])

    above, below = [], []
    for result in results:
        for box in result.boxes.cpu().numpy():
            coordinates = box.xyxy[0].astype(int)
            try:
                label = YOLOV8_LABELS[int(box.cls[0])]
            except (IndexError, TypeError, ValueError):
                label = "ERROR"  # class id outside the known label list
            try:
                confi = float(box.conf[0])
            except (IndexError, TypeError, ValueError):
                confi = 0.0
            line = _detection_line(label, confi, coordinates)
            if confi >= threshold:
                above.append(line)
            else:
                below.append(line)
    return render, _format_summary(above, below)


def _detect_with_transformers(model_name, image, threshold):
    """Run a DETR/YOLOS model on ``image``; return (annotated image, text report)."""
    feature_extractor, model = _load_transformers_model(model_name)
    processed_outputs = make_prediction(image, feature_extractor, model)[0]
    id2label = model.config.id2label
    viz_img = visualize_prediction(image, processed_outputs, threshold, id2label)

    detections = sorted(
        zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]),
        key=lambda item: item[0].item(),
        reverse=True,
    )
    above, below = [], []
    for score, label, box in detections:
        score_value = score.item()
        location = [round(coord, 2) for coord in box.tolist()]
        line = _detection_line(id2label[label.item()], round(score_value, 3), location)
        if score_value >= threshold:
            above.append(line)
        else:
            below.append(line)
    return viz_img, _format_summary(above, below)


def detect_objects(model_name, url_input, image_input, threshold):
    """Detect objects in the URL image (if the URL is valid), else the uploaded image.

    Args:
        model_name: Hugging Face model id. Ids containing ``yolov8`` use
            ultralyticsplus; ids containing ``detr``/``yolos`` use transformers.
        url_input: image URL; ignored unless it validates as a URL.
        image_input: uploaded PIL image, or None.
        threshold: score threshold that splits the text report and, for
            DETR/YOLOS, selects which boxes are drawn.

    Returns:
        Tuple of (annotated PIL image, text report).

    Raises:
        gr.Error: no model selected, no usable image, or an unsupported model.
    """
    if not model_name:
        raise gr.Error("Please select an object detection model.")
    image = _resolve_image(url_input, image_input)
    if 'yolov8' in model_name:
        return _detect_with_yolov8(model_name, image, threshold)
    return _detect_with_transformers(model_name, image, threshold)


def _detect_from_url(model_name, url_input, threshold):
    """URL-tab handler: detect on the URL only, never a stale upload."""
    return detect_objects(model_name, url_input, None, threshold)


def _detect_from_upload(model_name, image_input, threshold):
    """Upload-tab handler: detect on the upload only, never a stale URL."""
    return detect_objects(model_name, "", image_input, threshold)


def _example_value(example):
    """Return the first field of a clicked Dataset sample as a plain value/path.

    Depending on the Gradio version and component, a sample field may be the raw
    value (e.g. a path or URL string) or a processed file dict with a ``"path"`` key.
    """
    value = example[0]
    if isinstance(value, dict):
        return value.get("path")
    return value


def set_example_image(example: list) -> gr.Image:
    """Load the clicked example image into the upload component."""
    return gr.Image(value=_example_value(example))


def set_example_url(example: list) -> gr.Textbox:
    """Copy the clicked example URL into the URL textbox."""
    return gr.Textbox(value=_example_value(example))


def _example_image_paths(folder='images'):
    """Return JPEG files under ``folder`` (extension matched case-insensitively), sorted."""
    return sorted(
        path for path in pathlib.Path(folder).rglob('*')
        if path.is_file() and path.suffix.lower() in {'.jpg', '.jpeg'}
    )


title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""

description = """
Links to HuggingFace Models:

- [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
- [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
- [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
- [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
- [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
- [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
- [mshamrai/yolov8x-visdrone](https://huggingface.co/mshamrai/yolov8x-visdrone)
"""

models = ["facebook/detr-resnet-50", "facebook/detr-resnet-101", 'hustvl/yolos-small',
          'hustvl/yolos-tiny', 'facebook/detr-resnet-101-dc5', 'hustvl/yolos-small-300',
          'mshamrai/yolov8x-visdrone']
urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]

css = '''
h1#title {
  text-align: center;
}
'''
demo = gr.Blocks(css=css)


def changing():
    """Enable both Detect buttons once a model has been selected."""
    # https://discuss.huggingface.co/t/how-to-programmatically-enable-or-disable-components/52350/4
    return gr.Button('Detect', interactive=True), gr.Button('Detect', interactive=True)


with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    options = gr.Dropdown(choices=models, label='Select Object Detection Model', show_label=True)
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.7, label='Prediction Threshold')

    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(height=650, width=650)
            with gr.Row():
                example_url = gr.Dataset(components=[url_input], samples=[[str(url)] for url in urls])
            url_but = gr.Button('Detect', interactive=False)

        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil')
                img_output_from_upload = gr.Image(height=650, width=650)
            with gr.Row():
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[[path.as_posix()] for path in _example_image_paths()],
                )
            img_but = gr.Button('Detect', interactive=False)

    output_text1 = gr.components.Textbox(label="Confidence Values")

    # Buttons start disabled and are enabled once a model is chosen.
    options.change(fn=changing, inputs=[], outputs=[img_but, url_but])
    # Each button sends only its own tab's input so the other tab's state cannot leak in.
    url_but.click(_detect_from_url, inputs=[options, url_input, slider_input],
                  outputs=[img_output_from_url, output_text1], queue=True)
    img_but.click(_detect_from_upload, inputs=[options, img_input, slider_input],
                  outputs=[img_output_from_upload, output_text1], queue=True)
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[img_input])
    example_url.click(fn=set_example_url, inputs=[example_url], outputs=[url_input])


if __name__ == "__main__":
    demo.launch()