kernel-luso-comfort committed on
Commit
fc90b14
·
1 Parent(s): 354d315

Add legend to prediction overlay with dynamic font sizing and color boxes

Browse files
Files changed (1) hide show
  1. inference_utils/model.py +86 -41
inference_utils/model.py CHANGED
@@ -14,7 +14,7 @@ from dataclasses import dataclass
14
  import os
15
  from typing import Tuple
16
 
17
- from PIL import Image, ImageDraw
18
  from huggingface_hub import hf_hub_download
19
  import matplotlib.pyplot as plt
20
  import numpy as np
@@ -50,7 +50,10 @@ class Model:
50
  self._model, image, modality_type, targets
51
  )
52
  targets_not_found_str = (
53
- "\n".join(t.target for t in prediction_targets_not_found)
 
 
 
54
  if prediction_targets_not_found
55
  else "All targets were found!"
56
  )
@@ -104,51 +107,16 @@ def predict(
104
  )
105
  pt.adjusted_p_value = float(adj_p_value)
106
 
107
- pred_tasks_found, pred_tasks_not_found = segregate_prediction_tasks(
108
  prediction_tasks, 0.05
109
  )
110
 
111
  # Generate visualization
112
- colors = generate_colors(len(pred_tasks_found))
113
- masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_tasks_found))]
114
  pred_overlay = overlay_masks(image, masks, colors)
115
 
116
- # Add legend
117
- if len(pred_tasks_found) > 0:
118
- # Convert to numpy for manipulation
119
- pred_overlay = np.array(pred_overlay)
120
-
121
- # Calculate legend dimensions
122
- legend_height = 30 * len(pred_tasks_found) # 30 pixels per entry
123
- legend_padding = 10 # padding around legend
124
- total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding
125
-
126
- # Create new image with space for legend
127
- new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
128
- new_image[: pred_overlay.shape[0], :] = pred_overlay
129
- new_image[pred_overlay.shape[0] :] = 255 # White background for legend
130
-
131
- # Convert to PIL once for all legend entries
132
- img_pil = Image.fromarray(new_image)
133
- draw = ImageDraw.Draw(img_pil)
134
-
135
- # Draw legend entries
136
- start_y = pred_overlay.shape[0] + legend_padding
137
- for i, task in enumerate(pred_tasks_found):
138
- # Draw color box
139
- box_x = 10
140
- box_y = start_y + i * 30
141
- box_size = 20
142
- box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
143
- draw.rectangle(box_coords, fill=colors[i])
144
-
145
- # Draw text (vertically centered with color box)
146
- text_y = box_y + (box_size - 12) // 2 # Assuming ~12px text height
147
- draw.text((box_x + box_size + 10, text_y), task.target, fill=(0, 0, 0))
148
-
149
- pred_overlay = img_pil
150
- else:
151
- pred_overlay = Image.fromarray(np.array(pred_overlay))
152
 
153
  return pred_overlay, pred_tasks_not_found
154
 
@@ -194,3 +162,80 @@ def overlay_masks(
194
  np.uint8
195
  )
196
  return Image.fromarray(overlay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  import os
15
  from typing import Tuple
16
 
17
+ from PIL import Image, ImageDraw, ImageFont
18
  from huggingface_hub import hf_hub_download
19
  import matplotlib.pyplot as plt
20
  import numpy as np
 
50
  self._model, image, modality_type, targets
51
  )
52
  targets_not_found_str = (
53
+ "\n".join(
54
+ f"{t.target} ({t.adjusted_p_value, 2:.3f})"
55
+ for t in prediction_targets_not_found
56
+ )
57
  if prediction_targets_not_found
58
  else "All targets were found!"
59
  )
 
107
  )
108
  pt.adjusted_p_value = float(adj_p_value)
109
 
110
+ pred_targets_found, pred_tasks_not_found = segregate_prediction_tasks(
111
  prediction_tasks, 0.05
112
  )
113
 
114
  # Generate visualization
115
+ colors = generate_colors(len(pred_targets_found))
116
+ masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_targets_found))]
117
  pred_overlay = overlay_masks(image, masks, colors)
118
 
119
+ pred_overlay = add_legend(pred_overlay, pred_targets_found, colors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  return pred_overlay, pred_tasks_not_found
122
 
 
162
  np.uint8
163
  )
164
  return Image.fromarray(overlay)
165
+
166
+
167
# Candidate TrueType fonts, tried in order; the first one that loads is used.
_LEGEND_FONT_PATHS = (
    "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
    "/System/Library/Fonts/Helvetica.ttc",  # macOS
    "C:\\Windows\\Fonts\\arial.ttf",  # Windows
)


def _load_legend_font(font_size: int):
    """Return a font for legend text, preferring a scalable system font.

    Falls back to Pillow's built-in default font when none of the known
    system fonts can be opened.
    """
    for font_path in _LEGEND_FONT_PATHS:
        try:
            return ImageFont.truetype(font_path, font_size)
        except OSError:  # IOError is an alias of OSError on Python 3
            continue

    # Newer Pillow releases accept a size for the built-in font, which keeps
    # the legend text proportional to the image; older releases only offer a
    # fixed-size bitmap font and reject the keyword.
    try:
        return ImageFont.load_default(size=font_size)
    except TypeError:
        return ImageFont.load_default()


def add_legend(
    image: Image.Image,
    pred_targets_found: list[PredictionTarget],
    colors: list[Tuple[int, int, int]],
) -> Image.Image:
    """Append a colour-coded legend underneath a prediction overlay.

    One row is drawn per entry of ``pred_targets_found``: a filled square in
    the matching colour followed by ``"<target> (<adjusted p-value>)"`` with
    the p-value formatted to two decimals. Font, box, row and padding sizes
    all scale with the image width (minimum font size 16 px).

    Args:
        image: The overlay image to extend. Non-RGB images are converted to
            RGB so they can be placed on the RGB legend canvas.
        pred_targets_found: Targets to list; each must expose ``target`` and
            ``adjusted_p_value``.
        colors: Fill colours, indexed in parallel with
            ``pred_targets_found`` (must contain at least as many entries).

    Returns:
        ``image`` itself when there is nothing to list, otherwise a new RGB
        image with the legend on a white strip below the original pixels.
    """
    if not pred_targets_found:
        return image

    # The legend canvas is RGB; normalise so grayscale/RGBA overlays work too.
    overlay = image if image.mode == "RGB" else image.convert("RGB")
    width, height = overlay.size

    # All legend dimensions derive from the font size, which follows width.
    font_size = max(16, int(width * 0.02))
    box_size = int(font_size * 1.5)
    entry_height = int(box_size * 1.5)
    legend_padding = int(font_size * 0.75)

    legend_height = entry_height * len(pred_targets_found)
    total_height = height + legend_height + 2 * legend_padding

    # White canvas with the overlay pasted at the top; legend goes below it.
    canvas = Image.new("RGB", (width, total_height), (255, 255, 255))
    canvas.paste(overlay, (0, 0))
    draw = ImageDraw.Draw(canvas)

    font = _load_legend_font(font_size)

    # Measure tall + descending glyphs. The bbox top is the gap between the
    # drawing origin and the first inked row, so it has to be subtracted when
    # centring the text, otherwise the label sits too low in the box.
    _, glyph_top, _, glyph_bottom = font.getbbox("Aj")
    font_height = glyph_bottom - glyph_top

    box_x = legend_padding
    text_x = box_x + box_size + legend_padding
    first_row_y = height + legend_padding

    for i, task in enumerate(pred_targets_found):
        box_y = first_row_y + i * entry_height
        draw.rectangle(
            (box_x, box_y, box_x + box_size, box_y + box_size),
            fill=colors[i],
        )

        text_y = box_y + (box_size - font_height) // 2 - glyph_top
        draw.text(
            (text_x, text_y),
            f"{task.target} ({task.adjusted_p_value:.2f})",
            fill=(0, 0, 0),
            font=font,
        )

    return canvas