Spaces:
Sleeping
Sleeping
kernel-luso-comfort
commited on
Commit
·
fc90b14
1
Parent(s):
354d315
Add legend to prediction overlay with dynamic font sizing and color boxes
Browse files- inference_utils/model.py +86 -41
inference_utils/model.py
CHANGED
@@ -14,7 +14,7 @@ from dataclasses import dataclass
|
|
14 |
import os
|
15 |
from typing import Tuple
|
16 |
|
17 |
-
from PIL import Image, ImageDraw
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
import matplotlib.pyplot as plt
|
20 |
import numpy as np
|
@@ -50,7 +50,10 @@ class Model:
|
|
50 |
self._model, image, modality_type, targets
|
51 |
)
|
52 |
targets_not_found_str = (
|
53 |
-
"\n".join(
|
|
|
|
|
|
|
54 |
if prediction_targets_not_found
|
55 |
else "All targets were found!"
|
56 |
)
|
@@ -104,51 +107,16 @@ def predict(
|
|
104 |
)
|
105 |
pt.adjusted_p_value = float(adj_p_value)
|
106 |
|
107 |
-
|
108 |
prediction_tasks, 0.05
|
109 |
)
|
110 |
|
111 |
# Generate visualization
|
112 |
-
colors = generate_colors(len(
|
113 |
-
masks = [1 * (pred_mask[i] > 0.5) for i in range(len(
|
114 |
pred_overlay = overlay_masks(image, masks, colors)
|
115 |
|
116 |
-
|
117 |
-
if len(pred_tasks_found) > 0:
|
118 |
-
# Convert to numpy for manipulation
|
119 |
-
pred_overlay = np.array(pred_overlay)
|
120 |
-
|
121 |
-
# Calculate legend dimensions
|
122 |
-
legend_height = 30 * len(pred_tasks_found) # 30 pixels per entry
|
123 |
-
legend_padding = 10 # padding around legend
|
124 |
-
total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding
|
125 |
-
|
126 |
-
# Create new image with space for legend
|
127 |
-
new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
|
128 |
-
new_image[: pred_overlay.shape[0], :] = pred_overlay
|
129 |
-
new_image[pred_overlay.shape[0] :] = 255 # White background for legend
|
130 |
-
|
131 |
-
# Convert to PIL once for all legend entries
|
132 |
-
img_pil = Image.fromarray(new_image)
|
133 |
-
draw = ImageDraw.Draw(img_pil)
|
134 |
-
|
135 |
-
# Draw legend entries
|
136 |
-
start_y = pred_overlay.shape[0] + legend_padding
|
137 |
-
for i, task in enumerate(pred_tasks_found):
|
138 |
-
# Draw color box
|
139 |
-
box_x = 10
|
140 |
-
box_y = start_y + i * 30
|
141 |
-
box_size = 20
|
142 |
-
box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
|
143 |
-
draw.rectangle(box_coords, fill=colors[i])
|
144 |
-
|
145 |
-
# Draw text (vertically centered with color box)
|
146 |
-
text_y = box_y + (box_size - 12) // 2 # Assuming ~12px text height
|
147 |
-
draw.text((box_x + box_size + 10, text_y), task.target, fill=(0, 0, 0))
|
148 |
-
|
149 |
-
pred_overlay = img_pil
|
150 |
-
else:
|
151 |
-
pred_overlay = Image.fromarray(np.array(pred_overlay))
|
152 |
|
153 |
return pred_overlay, pred_tasks_not_found
|
154 |
|
@@ -194,3 +162,80 @@ def overlay_masks(
|
|
194 |
np.uint8
|
195 |
)
|
196 |
return Image.fromarray(overlay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
import os
|
15 |
from typing import Tuple
|
16 |
|
17 |
+
from PIL import Image, ImageDraw, ImageFont
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
import matplotlib.pyplot as plt
|
20 |
import numpy as np
|
|
|
50 |
self._model, image, modality_type, targets
|
51 |
)
|
52 |
targets_not_found_str = (
|
53 |
+
"\n".join(
|
54 |
+
f"{t.target} ({t.adjusted_p_value, 2:.3f})"
|
55 |
+
for t in prediction_targets_not_found
|
56 |
+
)
|
57 |
if prediction_targets_not_found
|
58 |
else "All targets were found!"
|
59 |
)
|
|
|
107 |
)
|
108 |
pt.adjusted_p_value = float(adj_p_value)
|
109 |
|
110 |
+
pred_targets_found, pred_tasks_not_found = segregate_prediction_tasks(
|
111 |
prediction_tasks, 0.05
|
112 |
)
|
113 |
|
114 |
# Generate visualization
|
115 |
+
colors = generate_colors(len(pred_targets_found))
|
116 |
+
masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_targets_found))]
|
117 |
pred_overlay = overlay_masks(image, masks, colors)
|
118 |
|
119 |
+
pred_overlay = add_legend(pred_overlay, pred_targets_found, colors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
return pred_overlay, pred_tasks_not_found
|
122 |
|
|
|
162 |
np.uint8
|
163 |
)
|
164 |
return Image.fromarray(overlay)
|
165 |
+
|
166 |
+
|
167 |
+
def add_legend(
|
168 |
+
image: Image.Image,
|
169 |
+
pred_targets_found: list[PredictionTarget],
|
170 |
+
colors: list[Tuple[int, int, int]],
|
171 |
+
) -> Image.Image:
|
172 |
+
if len(pred_targets_found) == 0:
|
173 |
+
return image
|
174 |
+
|
175 |
+
# Convert to numpy for manipulation
|
176 |
+
pred_overlay = np.array(image)
|
177 |
+
|
178 |
+
# Calculate dimensions based on image resolution
|
179 |
+
image_width = pred_overlay.shape[1]
|
180 |
+
font_size = max(16, int(image_width * 0.02)) # Scale with image width, minimum 16px
|
181 |
+
box_size = int(font_size * 1.5) # Color box proportional to font
|
182 |
+
entry_height = int(box_size * 1.5) # Space between entries
|
183 |
+
legend_padding = int(font_size * 0.75) # Padding scales with font
|
184 |
+
|
185 |
+
# Calculate total legend height
|
186 |
+
legend_height = entry_height * len(pred_targets_found)
|
187 |
+
total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding
|
188 |
+
|
189 |
+
# Create new image with space for legend
|
190 |
+
new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
|
191 |
+
new_image[: pred_overlay.shape[0], :] = pred_overlay
|
192 |
+
new_image[pred_overlay.shape[0] :] = 255 # White background for legend
|
193 |
+
|
194 |
+
# Convert to PIL once for all legend entries
|
195 |
+
img_pil = Image.fromarray(new_image)
|
196 |
+
draw = ImageDraw.Draw(img_pil)
|
197 |
+
|
198 |
+
# Try to load a system font with proper scaling
|
199 |
+
font = None
|
200 |
+
system_fonts = [
|
201 |
+
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", # Linux
|
202 |
+
"/System/Library/Fonts/Helvetica.ttc", # macOS
|
203 |
+
"C:\\Windows\\Fonts\\arial.ttf", # Windows
|
204 |
+
]
|
205 |
+
for font_path in system_fonts:
|
206 |
+
try:
|
207 |
+
font = ImageFont.truetype(font_path, font_size)
|
208 |
+
break
|
209 |
+
except (OSError, IOError):
|
210 |
+
continue
|
211 |
+
|
212 |
+
if font is None:
|
213 |
+
# Fallback to default font if no system fonts are available
|
214 |
+
font = ImageFont.load_default()
|
215 |
+
|
216 |
+
# Get font metrics for proper vertical centering
|
217 |
+
bbox = font.getbbox("Aj") # Use tall characters to get true height
|
218 |
+
font_height = bbox[3] - bbox[1] # bottom - top
|
219 |
+
|
220 |
+
# Draw legend entries
|
221 |
+
start_y = pred_overlay.shape[0] + legend_padding
|
222 |
+
for i, task in enumerate(pred_targets_found):
|
223 |
+
# Draw color box
|
224 |
+
box_x = legend_padding
|
225 |
+
box_y = start_y + i * entry_height
|
226 |
+
box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
|
227 |
+
draw.rectangle(box_coords, fill=colors[i])
|
228 |
+
|
229 |
+
# Draw text (vertically centered with color box)
|
230 |
+
text_y = box_y + (box_size - font_height) // 2 # Center text with box
|
231 |
+
# Format text with truncated p-value
|
232 |
+
p_value_truncated = "{:.2f}".format(task.adjusted_p_value)
|
233 |
+
legend_text = f"{task.target} ({p_value_truncated})"
|
234 |
+
draw.text(
|
235 |
+
(box_x + box_size + legend_padding, text_y),
|
236 |
+
legend_text,
|
237 |
+
fill=(0, 0, 0),
|
238 |
+
font=font,
|
239 |
+
)
|
240 |
+
|
241 |
+
return img_pil
|