freddyaboulton (HF staff) committed
Commit 7e1afe1 · verified · 1 parent: bd8d7c7

Upload folder using huggingface_hub

Files changed (5):
  1. app.py +83 -0
  2. index.html +310 -0
  3. inference.py +153 -0
  4. requirements.txt +4 -0
  5. utils.py +237 -0
app.py ADDED
@@ -0,0 +1,83 @@
+ import json
+ from pathlib import Path
+
+ import cv2
+ import gradio as gr
+ from fastapi import FastAPI
+ from fastapi.responses import HTMLResponse
+ from fastrtc import Stream, WebRTCError, get_twilio_turn_credentials
+ from gradio.utils import get_space
+ from huggingface_hub import hf_hub_download
+ from pydantic import BaseModel, Field
+
+ try:
+     from demo.object_detection.inference import YOLOv10
+ except (ImportError, ModuleNotFoundError):
+     from .inference import YOLOv10
+
+
+ cur_dir = Path(__file__).parent
+
+ model_file = hf_hub_download(
+     repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
+ )
+
+ model = YOLOv10(model_file)
+
+
+ def detection(image, conf_threshold=0.3):
+     try:
+         image = cv2.resize(image, (model.input_width, model.input_height))
+         print("conf_threshold", conf_threshold)
+         new_image = model.detect_objects(image, conf_threshold)
+         return cv2.resize(new_image, (500, 500))
+     except Exception as e:
+         import traceback
+
+         traceback.print_exc()
+         raise WebRTCError(str(e))
+
+
+ stream = Stream(
+     handler=detection,
+     modality="video",
+     mode="send-receive",
+     additional_inputs=[gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)],
+     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
+     concurrency_limit=20 if get_space() else None,
+ )
+
+ app = FastAPI()
+
+ stream.mount(app)
+
+
+ @app.get("/")
+ async def _():
+     rtc_config = get_twilio_turn_credentials() if get_space() else None
+     html_content = open(cur_dir / "index.html").read()
+     html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
+     return HTMLResponse(content=html_content)
+
+
+ class InputData(BaseModel):
+     webrtc_id: str
+     conf_threshold: float = Field(ge=0, le=1)
+
+
+ @app.post("/input_hook")
+ async def _(data: InputData):
+     stream.set_input(data.webrtc_id, data.conf_threshold)
+
+
+ if __name__ == "__main__":
+     import os
+
+     if (mode := os.getenv("MODE")) == "UI":
+         stream.ui.launch(server_port=7860, server_name="0.0.0.0")
+     elif mode == "PHONE":
+         stream.fastphone(host="0.0.0.0", port=7860)
+     else:
+         import uvicorn
+
+         uvicorn.run(app, host="0.0.0.0", port=7860)
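
For reference, the /input_hook route above is what index.html calls whenever the confidence slider moves: it receives the webrtc_id of an active stream plus the new threshold and forwards them to the handler via stream.set_input. A minimal sketch of the same request from Python, assuming the app.py server is running locally on port 7860; the webrtc_id here is a placeholder, since a real one is generated by index.html during the /webrtc/offer exchange:

import requests

# Placeholder id; in practice it is the webrtc_id created by index.html
# and sent with the /webrtc/offer request when the connection is set up.
payload = {"webrtc_id": "example-webrtc-id", "conf_threshold": 0.5}

# Update the confidence threshold for that stream on the running app.py server.
resp = requests.post("http://localhost:7860/input_hook", json=payload)
print(resp.status_code)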
index.html ADDED
@@ -0,0 +1,310 @@
+ <!DOCTYPE html>
+ <html lang="en">
+
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Object Detection</title>
+     <style>
+         body {
+             font-family: system-ui, -apple-system, sans-serif;
+             background: linear-gradient(135deg, #2d2b52 0%, #191731 100%);
+             color: white;
+             margin: 0;
+             padding: 20px;
+             height: 100vh;
+             box-sizing: border-box;
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+             justify-content: center;
+         }
+
+         .container {
+             width: 100%;
+             max-width: 800px;
+             text-align: center;
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+         }
+
+         .video-container {
+             width: 100%;
+             max-width: 500px;
+             aspect-ratio: 1/1;
+             background: rgba(255, 255, 255, 0.1);
+             border-radius: 12px;
+             overflow: hidden;
+             box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
+             margin: 10px 0;
+         }
+
+         #video-output {
+             width: 100%;
+             height: 100%;
+             object-fit: cover;
+         }
+
+         button {
+             background: white;
+             color: #2d2b52;
+             border: none;
+             padding: 12px 32px;
+             border-radius: 24px;
+             font-size: 16px;
+             font-weight: 600;
+             cursor: pointer;
+             transition: all 0.3s ease;
+             box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
+         }
+
+         button:hover {
+             transform: translateY(-2px);
+             box-shadow: 0 6px 16px rgba(0, 0, 0, 0.2);
+         }
+
+         h1 {
+             font-size: 2.5em;
+             margin-bottom: 0.3em;
+         }
+
+         p {
+             color: rgba(255, 255, 255, 0.8);
+             margin-bottom: 1em;
+         }
+
+         .controls {
+             display: flex;
+             flex-direction: column;
+             gap: 12px;
+             align-items: center;
+             margin-top: 10px;
+         }
+
+         .slider-container {
+             width: 100%;
+             max-width: 300px;
+             display: flex;
+             flex-direction: column;
+             gap: 8px;
+         }
+
+         .slider-container label {
+             color: rgba(255, 255, 255, 0.8);
+             font-size: 14px;
+         }
+
+         input[type="range"] {
+             width: 100%;
+             height: 6px;
+             -webkit-appearance: none;
+             background: rgba(255, 255, 255, 0.1);
+             border-radius: 3px;
+             outline: none;
+         }
+
+         input[type="range"]::-webkit-slider-thumb {
+             -webkit-appearance: none;
+             width: 18px;
+             height: 18px;
+             background: white;
+             border-radius: 50%;
+             cursor: pointer;
+         }
+
+         /* Styles for toast notifications */
+         .toast {
+             position: fixed;
+             top: 20px;
+             left: 50%;
+             transform: translateX(-50%);
+             background-color: #f44336;
+             color: white;
+             padding: 16px 24px;
+             border-radius: 4px;
+             font-size: 14px;
+             z-index: 1000;
+             display: none;
+             box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
+         }
+     </style>
+ </head>
+
+ <body>
+     <!-- Toast element for error notifications -->
+     <div id="error-toast" class="toast"></div>
+     <div class="container">
+         <h1>Real-time Object Detection</h1>
+         <p>Using YOLOv10 to detect objects in your webcam feed</p>
+         <div class="video-container">
+             <video id="video-output" autoplay playsinline></video>
+         </div>
+         <div class="controls">
+             <div class="slider-container">
+                 <label>Confidence Threshold: <span id="conf-value">0.3</span></label>
+                 <input type="range" id="conf-threshold" min="0" max="1" step="0.01" value="0.3">
+             </div>
+             <button id="start-button">Start</button>
+         </div>
+     </div>
+
+     <script>
+         let peerConnection;
+         let webrtc_id;
+         const startButton = document.getElementById('start-button');
+         const videoOutput = document.getElementById('video-output');
+         const confThreshold = document.getElementById('conf-threshold');
+         const confValue = document.getElementById('conf-value');
+
+         // Update confidence value display
+         confThreshold.addEventListener('input', (e) => {
+             confValue.textContent = e.target.value;
+             if (peerConnection) {
+                 updateConfThreshold(e.target.value);
+             }
+         });
+
+         function updateConfThreshold(value) {
+             fetch('/input_hook', {
+                 method: 'POST',
+                 headers: {
+                     'Content-Type': 'application/json',
+                 },
+                 body: JSON.stringify({
+                     webrtc_id: webrtc_id,
+                     conf_threshold: parseFloat(value)
+                 })
+             });
+         }
+
+         function showError(message) {
+             const toast = document.getElementById('error-toast');
+             toast.textContent = message;
+             toast.style.display = 'block';
+
+             // Hide toast after 5 seconds
+             setTimeout(() => {
+                 toast.style.display = 'none';
+             }, 5000);
+         }
+
+         async function setupWebRTC() {
+             const config = __RTC_CONFIGURATION__;
+             peerConnection = new RTCPeerConnection(config);
+
+             try {
+                 const stream = await navigator.mediaDevices.getUserMedia({
+                     video: true
+                 });
+
+                 stream.getTracks().forEach(track => {
+                     peerConnection.addTrack(track, stream);
+                 });
+
+                 peerConnection.addEventListener('track', (evt) => {
+                     if (videoOutput && videoOutput.srcObject !== evt.streams[0]) {
+                         videoOutput.srcObject = evt.streams[0];
+                     }
+                 });
+
+                 const dataChannel = peerConnection.createDataChannel('text');
+                 dataChannel.onmessage = (event) => {
+                     const eventJson = JSON.parse(event.data);
+                     if (eventJson.type === "error") {
+                         showError(eventJson.message);
+                     } else if (eventJson.type === "send_input") {
+                         updateConfThreshold(confThreshold.value);
+                     }
+                 };
+
+                 const offer = await peerConnection.createOffer();
+                 await peerConnection.setLocalDescription(offer);
+
+                 await new Promise((resolve) => {
+                     if (peerConnection.iceGatheringState === "complete") {
+                         resolve();
+                     } else {
+                         const checkState = () => {
+                             if (peerConnection.iceGatheringState === "complete") {
+                                 peerConnection.removeEventListener("icegatheringstatechange", checkState);
+                                 resolve();
+                             }
+                         };
+                         peerConnection.addEventListener("icegatheringstatechange", checkState);
+                     }
+                 });
+
+                 webrtc_id = Math.random().toString(36).substring(7);
+
+                 const response = await fetch('/webrtc/offer', {
+                     method: 'POST',
+                     headers: { 'Content-Type': 'application/json' },
+                     body: JSON.stringify({
+                         sdp: peerConnection.localDescription.sdp,
+                         type: peerConnection.localDescription.type,
+                         webrtc_id: webrtc_id
+                     })
+                 });
+
+                 const serverResponse = await response.json();
+
+                 if (serverResponse.status === 'failed') {
+                     showError(serverResponse.meta.error === 'concurrency_limit_reached'
+                         ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
+                         : serverResponse.meta.error);
+                     stop();
+                     startButton.textContent = 'Start';
+                     return;
+                 }
+
+                 await peerConnection.setRemoteDescription(serverResponse);
+
+                 // Send initial confidence threshold
+                 updateConfThreshold(confThreshold.value);
+
+             } catch (err) {
+                 console.error('Error setting up WebRTC:', err);
+                 showError('Failed to establish connection. Please try again.');
+                 stop();
+                 startButton.textContent = 'Start';
+             }
+         }
+
+         function stop() {
+             if (peerConnection) {
+                 if (peerConnection.getTransceivers) {
+                     peerConnection.getTransceivers().forEach(transceiver => {
+                         if (transceiver.stop) {
+                             transceiver.stop();
+                         }
+                     });
+                 }
+
+                 if (peerConnection.getSenders) {
+                     peerConnection.getSenders().forEach(sender => {
+                         if (sender.track && sender.track.stop) sender.track.stop();
+                     });
+                 }
+
+                 setTimeout(() => {
+                     peerConnection.close();
+                 }, 500);
+             }
+
+             videoOutput.srcObject = null;
+         }
+
+         startButton.addEventListener('click', () => {
+             if (startButton.textContent === 'Start') {
+                 setupWebRTC();
+                 startButton.textContent = 'Stop';
+             } else {
+                 stop();
+                 startButton.textContent = 'Start';
+             }
+         });
+     </script>
+ </body>
+
+ </html>
inference.py ADDED
@@ -0,0 +1,153 @@
+ import time
+
+ import cv2
+ import numpy as np
+ import onnxruntime
+
+ try:
+     from demo.object_detection.utils import draw_detections
+ except (ImportError, ModuleNotFoundError):
+     from utils import draw_detections
+
+
+ class YOLOv10:
+     def __init__(self, path):
+         # Initialize model
+         self.initialize_model(path)
+
+     def __call__(self, image):
+         return self.detect_objects(image)
+
+     def initialize_model(self, path):
+         self.session = onnxruntime.InferenceSession(
+             path, providers=onnxruntime.get_available_providers()
+         )
+         # Get model info
+         self.get_input_details()
+         self.get_output_details()
+
+     def detect_objects(self, image, conf_threshold=0.3):
+         input_tensor = self.prepare_input(image)
+
+         # Perform inference on the image
+         new_image = self.inference(image, input_tensor, conf_threshold)
+
+         return new_image
+
+     def prepare_input(self, image):
+         self.img_height, self.img_width = image.shape[:2]
+
+         input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         # Resize input image
+         input_img = cv2.resize(input_img, (self.input_width, self.input_height))
+
+         # Scale input pixel values to 0 to 1
+         input_img = input_img / 255.0
+         input_img = input_img.transpose(2, 0, 1)
+         input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
+
+         return input_tensor
+
+     def inference(self, image, input_tensor, conf_threshold=0.3):
+         start = time.perf_counter()
+         outputs = self.session.run(
+             self.output_names, {self.input_names[0]: input_tensor}
+         )
+
+         print(f"Inference time: {(time.perf_counter() - start) * 1000:.2f} ms")
+         (
+             boxes,
+             scores,
+             class_ids,
+         ) = self.process_output(outputs, conf_threshold)
+         return self.draw_detections(image, boxes, scores, class_ids)
+
+     def process_output(self, output, conf_threshold=0.3):
+         predictions = np.squeeze(output[0])
+
+         # Filter out object confidence scores below threshold
+         scores = predictions[:, 4]
+         predictions = predictions[scores > conf_threshold, :]
+         scores = scores[scores > conf_threshold]
+
+         if len(scores) == 0:
+             return [], [], []
+
+         # Get the class with the highest confidence
+         class_ids = predictions[:, 5].astype(int)
+
+         # Get bounding boxes for each object
+         boxes = self.extract_boxes(predictions)
+
+         return boxes, scores, class_ids
+
+     def extract_boxes(self, predictions):
+         # Extract boxes from predictions
+         boxes = predictions[:, :4]
+
+         # Scale boxes to original image dimensions
+         boxes = self.rescale_boxes(boxes)
+
+         # Boxes are already in xyxy format here, so the xywh conversion stays disabled
+         # boxes = xywh2xyxy(boxes)
+
+         return boxes
+
+     def rescale_boxes(self, boxes):
+         # Rescale boxes to original image dimensions
+         input_shape = np.array(
+             [self.input_width, self.input_height, self.input_width, self.input_height]
+         )
+         boxes = np.divide(boxes, input_shape, dtype=np.float32)
+         boxes *= np.array(
+             [self.img_width, self.img_height, self.img_width, self.img_height]
+         )
+         return boxes
+
+     def draw_detections(
+         self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
+     ):
+         return draw_detections(image, boxes, scores, class_ids, mask_alpha)
+
+     def get_input_details(self):
+         model_inputs = self.session.get_inputs()
+         self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
+
+         self.input_shape = model_inputs[0].shape
+         self.input_height = self.input_shape[2]
+         self.input_width = self.input_shape[3]
+
+     def get_output_details(self):
+         model_outputs = self.session.get_outputs()
+         self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
+
+
+ if __name__ == "__main__":
+     import tempfile
+
+     import requests
+     from huggingface_hub import hf_hub_download
+
+     model_file = hf_hub_download(
+         repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
+     )
+
+     yolov10_detector = YOLOv10(model_file)
+
+     with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
+         f.write(
+             requests.get(
+                 "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
+             ).content
+         )
+         f.seek(0)
+         img = cv2.imread(f.name)
+
+     # Detect objects
+     combined_image = yolov10_detector.detect_objects(img)
+
+     # Draw detections
+     cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
+     cv2.imshow("Output", combined_image)
+     cv2.waitKey(0)
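
The __main__ block above exercises the full draw-and-display path. When only the raw detections are needed, the same class can be used without the drawing step, since process_output already returns boxes rescaled to the original image plus scores and class ids. A minimal sketch, assuming it is run from this demo's directory (so inference.py and utils.py import as top-level modules), that "sample.jpg" is a placeholder image path, and that the yolov10n checkpoint from app.py is used:

import cv2
from huggingface_hub import hf_hub_download

from inference import YOLOv10
from utils import class_names

# Same checkpoint that app.py downloads; "sample.jpg" is a placeholder path.
model_file = hf_hub_download(
    repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
)
detector = YOLOv10(model_file)

img = cv2.imread("sample.jpg")

# Run the ONNX session directly and keep the raw detections
# instead of the drawn image returned by detect_objects().
input_tensor = detector.prepare_input(img)
outputs = detector.session.run(
    detector.output_names, {detector.input_names[0]: input_tensor}
)
boxes, scores, class_ids = detector.process_output(outputs, conf_threshold=0.3)

for box, score, class_id in zip(boxes, scores, class_ids):
    print(class_names[class_id], f"{score:.2f}", box.astype(int))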
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ fastrtc
+ opencv-python
+ twilio
+ onnxruntime-gpu
utils.py ADDED
@@ -0,0 +1,237 @@
+ import cv2
+ import numpy as np
+
+ class_names = [
+     "person",
+     "bicycle",
+     "car",
+     "motorcycle",
+     "airplane",
+     "bus",
+     "train",
+     "truck",
+     "boat",
+     "traffic light",
+     "fire hydrant",
+     "stop sign",
+     "parking meter",
+     "bench",
+     "bird",
+     "cat",
+     "dog",
+     "horse",
+     "sheep",
+     "cow",
+     "elephant",
+     "bear",
+     "zebra",
+     "giraffe",
+     "backpack",
+     "umbrella",
+     "handbag",
+     "tie",
+     "suitcase",
+     "frisbee",
+     "skis",
+     "snowboard",
+     "sports ball",
+     "kite",
+     "baseball bat",
+     "baseball glove",
+     "skateboard",
+     "surfboard",
+     "tennis racket",
+     "bottle",
+     "wine glass",
+     "cup",
+     "fork",
+     "knife",
+     "spoon",
+     "bowl",
+     "banana",
+     "apple",
+     "sandwich",
+     "orange",
+     "broccoli",
+     "carrot",
+     "hot dog",
+     "pizza",
+     "donut",
+     "cake",
+     "chair",
+     "couch",
+     "potted plant",
+     "bed",
+     "dining table",
+     "toilet",
+     "tv",
+     "laptop",
+     "mouse",
+     "remote",
+     "keyboard",
+     "cell phone",
+     "microwave",
+     "oven",
+     "toaster",
+     "sink",
+     "refrigerator",
+     "book",
+     "clock",
+     "vase",
+     "scissors",
+     "teddy bear",
+     "hair drier",
+     "toothbrush",
+ ]
+
+ # Create one random RGB color per class
+ rng = np.random.default_rng(3)
+ colors = rng.uniform(0, 255, size=(len(class_names), 3))
+
+
+ def nms(boxes, scores, iou_threshold):
+     # Sort by score
+     sorted_indices = np.argsort(scores)[::-1]
+
+     keep_boxes = []
+     while sorted_indices.size > 0:
+         # Pick the box with the highest remaining score
+         box_id = sorted_indices[0]
+         keep_boxes.append(box_id)
+
+         # Compute IoU of the picked box with the rest
+         ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
+
+         # Remove boxes with IoU over the threshold
+         keep_indices = np.where(ious < iou_threshold)[0]
+
+         sorted_indices = sorted_indices[keep_indices + 1]
+
+     return keep_boxes
+
+
+ def multiclass_nms(boxes, scores, class_ids, iou_threshold):
+     unique_class_ids = np.unique(class_ids)
+
+     keep_boxes = []
+     for class_id in unique_class_ids:
+         class_indices = np.where(class_ids == class_id)[0]
+         class_boxes = boxes[class_indices, :]
+         class_scores = scores[class_indices]
+
+         class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)
+         keep_boxes.extend(class_indices[class_keep_boxes])
+
+     return keep_boxes
+
+
+ def compute_iou(box, boxes):
+     # Compute xmin, ymin, xmax, ymax for both boxes
+     xmin = np.maximum(box[0], boxes[:, 0])
+     ymin = np.maximum(box[1], boxes[:, 1])
+     xmax = np.minimum(box[2], boxes[:, 2])
+     ymax = np.minimum(box[3], boxes[:, 3])
+
+     # Compute intersection area
+     intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
+
+     # Compute union area
+     box_area = (box[2] - box[0]) * (box[3] - box[1])
+     boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+     union_area = box_area + boxes_area - intersection_area
+
+     # Compute IoU
+     iou = intersection_area / union_area
+
+     return iou
+
+
+ def xywh2xyxy(x):
+     # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
+     y = np.copy(x)
+     y[..., 0] = x[..., 0] - x[..., 2] / 2
+     y[..., 1] = x[..., 1] - x[..., 3] / 2
+     y[..., 2] = x[..., 0] + x[..., 2] / 2
+     y[..., 3] = x[..., 1] + x[..., 3] / 2
+     return y
+
+
+ def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
+     det_img = image.copy()
+
+     img_height, img_width = image.shape[:2]
+     font_size = min([img_height, img_width]) * 0.0006
+     text_thickness = int(min([img_height, img_width]) * 0.001)
+
+     # det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)
+
+     # Draw bounding boxes and labels of detections
+     for class_id, box, score in zip(class_ids, boxes, scores):
+         color = colors[class_id]
+
+         draw_box(det_img, box, color)  # type: ignore
+
+         label = class_names[class_id]
+         caption = f"{label} {int(score * 100)}%"
+         draw_text(det_img, caption, box, color, font_size, text_thickness)  # type: ignore
+
+     return det_img
+
+
+ def draw_box(
+     image: np.ndarray,
+     box: np.ndarray,
+     color: tuple[int, int, int] = (0, 0, 255),
+     thickness: int = 2,
+ ) -> np.ndarray:
+     x1, y1, x2, y2 = box.astype(int)
+     return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
+
+
+ def draw_text(
+     image: np.ndarray,
+     text: str,
+     box: np.ndarray,
+     color: tuple[int, int, int] = (0, 0, 255),
+     font_size: float = 0.001,
+     text_thickness: int = 2,
+ ) -> np.ndarray:
+     x1, y1, x2, y2 = box.astype(int)
+     (tw, th), _ = cv2.getTextSize(
+         text=text,
+         fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+         fontScale=font_size,
+         thickness=text_thickness,
+     )
+     th = int(th * 1.2)
+
+     cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
+
+     return cv2.putText(
+         image,
+         text,
+         (x1, y1),
+         cv2.FONT_HERSHEY_SIMPLEX,
+         font_size,
+         (255, 255, 255),
+         text_thickness,
+         cv2.LINE_AA,
+     )
+
+
+ def draw_masks(
+     image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3
+ ) -> np.ndarray:
+     mask_img = image.copy()
+
+     # Draw a filled rectangle for each detection on the mask image
+     for box, class_id in zip(boxes, classes):
+         color = colors[class_id]
+
+         x1, y1, x2, y2 = box.astype(int)
+
+         # Draw fill rectangle in mask image
+         cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)  # type: ignore
+
+     return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
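
As a quick sanity check on the IoU and NMS helpers above (a sketch assuming utils.py is importable from the working directory), two heavily overlapping boxes plus one distant box behave as expected:

import numpy as np

from utils import compute_iou, nms

# Boxes in xyxy format: two overlapping boxes and one far away.
boxes = np.array(
    [
        [0, 0, 10, 10],
        [1, 1, 11, 11],
        [50, 50, 60, 60],
    ],
    dtype=np.float32,
)
scores = np.array([0.9, 0.8, 0.7])

# IoU of the first box against the other two: roughly 0.68 and 0.0.
print(compute_iou(boxes[0], boxes[1:]))

# With an IoU threshold of 0.5, NMS keeps the higher-scoring box of the
# overlapping pair and the distant box (indices 0 and 2).
print(nms(boxes, scores, iou_threshold=0.5))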