# gatesla's picture
# Testing commit ability
# e0089c8
import io
import matplotlib.pyplot as plt
import requests, validators
import torch
import pathlib
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
from ultralyticsplus import YOLO, render_result
import os
# Box outline colours used by visualize_prediction, one [r, g, b] triple per
# entry with each component given as a float between 0 and 1.
COLORS = [
    [0.0, 0.447, 0.741],
    [0.85, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Class names looked up by the integer class id of each box in the YOLOv8
# branch of detect_objects.
# NOTE(review): presumably the label order of the VisDrone-trained checkpoint
# listed in `models` — confirm against the model card.
YOLOV8_LABELS = [
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
]
def make_prediction(img, feature_extractor, model):
    """Run one image through a detection model and post-process the raw output.

    Args:
        img: Image to analyse. Must expose ``.size``; the reversed ``size``
            tuple is used as the target size (presumably a PIL image whose
            ``size`` is ``(width, height)`` — confirm against callers).
        feature_extractor: Callable preprocessor accepting
            ``(img, return_tensors="pt")`` that also provides
            ``post_process(outputs, target_sizes)``.
        model: Callable invoked with the preprocessor output unpacked as
            keyword arguments.

    Returns:
        Whatever ``feature_extractor.post_process`` returns; ``detect_objects``
        reads element 0 of it.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: disabling autograd avoids building a graph that is never
    # back-propagated through, which saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # One row per image, holding the image size in reversed order.
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    return feature_extractor.post_process(outputs, target_sizes)
def fig2img(fig):
    """Save a figure into an in-memory buffer and open that buffer as a PIL image.

    Args:
        fig: Object exposing ``savefig(file_like)``; a Matplotlib figure in this
            file's usage.

    Returns:
        The result of ``Image.open`` on the buffer the figure was saved into.
    """
    image_buffer = io.BytesIO()
    fig.savefig(image_buffer)
    # Rewind so the image reader starts at the first byte that was written.
    image_buffer.seek(0)
    # The buffer is deliberately left open: the returned image reads from it.
    return Image.open(image_buffer)
def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
    """Draw the detections scoring above ``threshold`` on top of the image.

    Args:
        pil_img: Image shown as the background of the plot.
        output_dict: Mapping with ``"scores"``, ``"boxes"`` and ``"labels"``
            entries that support boolean-mask indexing and ``.tolist()``
            (tensors in this file's usage). Each box unpacks as
            ``(xmin, ymin, xmax, ymax)``.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.
        id2label: Optional mapping from label id to display name; when ``None``
            the raw label ids are shown.

    Returns:
        The rendered figure converted to an image by ``fig2img``.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        ax.imshow(pil_img)
        for index, (score, (xmin, ymin, xmax, ymax), label) in enumerate(zip(scores, boxes, labels)):
            # Wrap around the palette so no detection is dropped, however many there are.
            color = COLORS[index % len(COLORS)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
            ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; without
        # this, each call would leave another 16x10 figure in memory.
        plt.close(fig)
def _format_detection_report(above_lines, below_lines):
    """Join pre-formatted detection lines into the two-section text report."""
    # https://docs.python.org/3/library/string.html#format-examples
    return (
        "{:*^50}\n".format("ABOVE THRESHOLD OR EQUAL")
        + "".join(above_lines)
        + "\n{:*^50}\n".format("BELOW THRESHOLD")
        + "".join(below_lines)
    )


def _load_input_image(url_input, image_input):
    """Return the image to analyse, preferring a valid URL over an uploaded image."""
    if validators.url(url_input):
        # A timeout keeps an unresponsive host from blocking the call forever.
        response = requests.get(url_input, stream=True, timeout=30)
        return Image.open(response.raw)
    if image_input:
        return image_input
    raise ValueError("No input image: provide a valid image URL or an uploaded image.")


def detect_objects(model_name, url_input, image_input, threshold):
    """Detect objects in an image and describe every detection in a text report.

    Args:
        model_name: Model identifier. Names containing ``'yolov8'`` are loaded
            with ``YOLO``; otherwise names containing ``'detr'`` or ``'yolos'``
            are loaded with the matching transformers class.
        url_input: Image URL. Used only by the DETR/YOLOS branch, where a valid
            URL takes precedence over ``image_input``.
        image_input: Image object to analyse.
        threshold: Confidence at or above which a detection is listed in the
            "ABOVE THRESHOLD OR EQUAL" section of the report; the rest go in
            the "BELOW THRESHOLD" section.

    Returns:
        A ``(image, report)`` tuple: the annotated image and the report string.

    Raises:
        ValueError: In the DETR/YOLOS branch, when ``model_name`` matches no
            supported family or when neither a valid URL nor an image is given.
    """
    if 'yolov8' in model_name:
        # Working on getting this to work, another approach
        # https://docs.ultralytics.com/modes/predict/#key-features-of-predict-mode
        # NOTE(review): url_input is not consulted in this branch; only
        # image_input is passed to the model.
        model = YOLO(model_name)
        # set model parameters
        model.overrides['conf'] = 0.15  # NMS confidence threshold
        model.overrides['iou'] = 0.05  # NMS IoU (intersection over union) threshold
        model.overrides['agnostic_nms'] = False  # NMS class-agnostic
        model.overrides['max_det'] = 1000  # maximum number of detections per image
        results = model.predict(image_input)
        render = render_result(model=model, image=image_input, result=results[0])

        above_lines = []
        below_lines = []
        for result in results:
            for box in result.boxes.cpu().numpy():
                coordinates = box.xyxy[0].astype(int)
                # Best effort: an unknown class id or an unconvertible value
                # yields a placeholder instead of aborting the whole report.
                try:
                    label = YOLOV8_LABELS[int(box.cls)]
                except (IndexError, TypeError, ValueError):
                    label = "ERROR"
                try:
                    confi = float(box.conf)
                except (TypeError, ValueError):
                    confi = 0.0
                line = f"Detected `{label}` with confidence `{confi}` at location `{coordinates}`\n"
                if confi >= threshold:
                    above_lines.append(line)
                else:
                    below_lines.append(line)
        return render, _format_detection_report(above_lines, below_lines)

    # Pick the model class first so that an unsupported name fails before any
    # download or image fetch is attempted.
    if 'detr' in model_name:
        model_cls = DetrForObjectDetection
    elif 'yolos' in model_name:
        model_cls = YolosForObjectDetection
    else:
        raise ValueError(
            f"Unsupported model name {model_name!r}: expected it to contain 'yolov8', 'detr' or 'yolos'."
        )

    image = _load_input_image(url_input, image_input)

    # Extract model and feature extractor
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = model_cls.from_pretrained(model_name)

    # Make prediction; the first element holds the detections for the single image.
    processed_outputs = make_prediction(image, feature_extractor, model)[0]

    # Visualize prediction
    id2label = model.config.id2label
    viz_img = visualize_prediction(image, processed_outputs, threshold, id2label)

    # List detections from the highest score to the lowest.
    ranked = sorted(
        zip(processed_outputs["scores"], processed_outputs["labels"], processed_outputs["boxes"]),
        key=lambda detection: detection[0].item(),
        reverse=True,
    )
    above_lines = []
    below_lines = []
    for score, label, box in ranked:
        confidence = score.item()
        rounded_box = [round(coord, 2) for coord in box.tolist()]
        line = (
            f"Detected `{id2label[label.item()]}` with confidence "
            f"`{round(confidence, 3)}` at location `{rounded_box}`\n"
        )
        if confidence >= threshold:
            above_lines.append(line)
        else:
            below_lines.append(line)
    return viz_img, _format_detection_report(above_lines, below_lines)
# HTML heading text for the app.
title = '<h1 id="title">Object Detection App with DETR and YOLOS</h1>'

# Markdown list linking to the model card of every entry in `models`.
description = """
Links to HuggingFace Models:
- [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
- [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
- [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
- [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
- [facebook/detr-resnet-101-dc5](https://huggingface.co/facebook/detr-resnet-101-dc5)
- [hustvl/yolos-small-300](https://huggingface.co/hustvl/yolos-small-300)
- [mshamrai/yolov8x-visdrone](https://huggingface.co/mshamrai/yolov8x-visdrone)
"""

# Model identifiers, in the same order as the links in `description`.
models = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-101",
    "hustvl/yolos-small",
    "hustvl/yolos-tiny",
    "facebook/detr-resnet-101-dc5",
    "hustvl/yolos-small-300",
    "mshamrai/yolov8x-visdrone",
]

# Image URLs kept alongside the model list (a single entry at present).
urls = [
    "https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg",
]
# Sample image for manual testing. It is opened when this module is imported,
# and the relative path is resolved against the current working directory, so
# importing from another directory fails if the file cannot be found.
TEST_IMAGE = Image.open(r"images/Test_Street_VisDrone.JPG")
# Example call from a Python terminal, once this module has been imported as image_functions:
# image_functions.detect_objects('facebook/detr-resnet-50', "", image_functions.TEST_IMAGE, 0.7)