ammariii08 committed on
Commit e409fcf · verified · 1 Parent(s): 7a18006

Update app.py

Files changed (1)
  1. app.py +112 -463
app.py CHANGED
@@ -1,472 +1,121 @@
-import os
-from pathlib import Path
-from typing import List, Union
-from PIL import Image
-import ezdxf.units
-import numpy as np
 import torch
-from torchvision import transforms
-from ultralytics import YOLOWorld, YOLO
-from ultralytics.engine.results import Results
-from ultralytics.utils.plotting import save_one_box
-from transformers import AutoModelForImageSegmentation
-import cv2
-import ezdxf
 import gradio as gr
-import zipfile
-import datetime
-
-from scalingtestupdated import calculate_scaling_factor
-from shapely.geometry import Polygon, Point
-from scipy.interpolate import splprep, splev
-from scipy.ndimage import gaussian_filter1d
-
-###############################################################################
-# 1) Single-Image Pipeline & Utilities (Simplified)
-###############################################################################
-
-# Load Segmentation Model (BiRefNet)
-birefnet = AutoModelForImageSegmentation.from_pretrained(
-    "zhengpeng7/BiRefNet", trust_remote_code=True
-)
-device = "cpu"
-torch.set_float32_matmul_precision(["high", "highest"][0])
-birefnet.to(device)
-birefnet.eval()
-
-transform_image = transforms.Compose([
-    transforms.Resize((1024, 1024)),
-    transforms.ToTensor(),
-    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-])
-
-
-def yolo_detect(image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
-                classes: List[str]) -> np.ndarray:
-    """Detects the drawer (box) in the image using YOLOWorld."""
-    drawer_detector = YOLOWorld("yolov8x-worldv2.pt")
-    drawer_detector.set_classes(classes)
-    results: List[Results] = drawer_detector.predict(image)
-    boxes = []
-    for result in results:
-        boxes.append(save_one_box(result.cpu().boxes.xyxy, im=result.orig_img, save=False))
-    del drawer_detector
-    return boxes[0]
-
-
-def resize_img(img: np.ndarray, resize_dim):
-    return np.array(Image.fromarray(img).resize(resize_dim))
-
-
-def remove_bg(image: np.ndarray) -> np.ndarray:
-    """Removes background using BiRefNet, returning a binary mask."""
-    image_pil = Image.fromarray(image)
-    input_images = transform_image(image_pil).unsqueeze(0).to(device)
-    with torch.no_grad():
-        preds = birefnet(input_images)[-1].sigmoid().cpu()
-    pred = preds[0].squeeze()
-    pred_pil: Image = transforms.ToPILImage()(pred)
-    scale_ratio = 1024 / max(image_pil.size)
-    scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
-    return np.array(pred_pil.resize(scaled_size))
-
-
-def make_square(img: np.ndarray):
-    """Pads an image to be square (max dimension)."""
-    height, width = img.shape[:2]
-    max_dim = max(height, width)
-    pad_height = (max_dim - height) // 2
-    pad_width = (max_dim - width) // 2
-    pad_height_extra = max_dim - height - 2 * pad_height
-    pad_width_extra = max_dim - width - 2 * pad_width
-    if len(img.shape) == 3:
-        padded = np.pad(img, ((pad_height, pad_height + pad_height_extra),
-                              (pad_width, pad_width + pad_width_extra), (0, 0)), mode="edge")
-    else:
-        padded = np.pad(img, ((pad_height, pad_height + pad_height_extra),
-                              (pad_width, pad_width + pad_width_extra)), mode="edge")
-    return padded
-
-
-def exclude_scaling_box(image: np.ndarray, bbox: np.ndarray, orig_size: tuple, processed_size: tuple,
-                        expansion_factor: float = 1.2) -> np.ndarray:
-    """Zeros out the area of the reference square from the binary mask."""
-    x_min, y_min, x_max, y_max = map(int, bbox)
-    scale_x = processed_size[1] / orig_size[1]
-    scale_y = processed_size[0] / orig_size[0]
-    x_min = int(x_min * scale_x)
-    x_max = int(x_max * scale_x)
-    y_min = int(y_min * scale_y)
-    y_max = int(y_max * scale_y)
-    box_width = x_max - x_min
-    box_height = y_max - y_min
-    expanded_x_min = max(0, int(x_min - (expansion_factor - 1) * box_width / 2))
-    expanded_x_max = min(image.shape[1], int(x_max + (expansion_factor - 1) * box_width / 2))
-    expanded_y_min = max(0, int(y_min - (expansion_factor - 1) * box_height / 2))
-    expanded_y_max = min(image.shape[0], int(y_max + (expansion_factor - 1) * box_height / 2))
-    image[expanded_y_min:expanded_y_max, expanded_x_min:expanded_x_max] = 0
-    return image
-
-
-def resample_contour(contour):
-    """Resamples a contour to ~1000 points using spline interpolation and smoothing."""
-    num_points = 1000
-    smoothing_factor = 5
-    spline_degree = 3
-    if len(contour) < spline_degree + 1:
-        raise ValueError("Contour must have at least 4 points.")
-    contour = contour[:, 0, :]
-    tck, _ = splprep([contour[:, 0], contour[:, 1]], s=smoothing_factor)
-    u = np.linspace(0, 1, num_points)
-    resampled_points = splev(u, tck)
-    smoothed_x = gaussian_filter1d(resampled_points[0], sigma=1)
-    smoothed_y = gaussian_filter1d(resampled_points[1], sigma=1)
-    return np.array([smoothed_x, smoothed_y]).T
-
-
-def extract_outlines(binary_image: np.ndarray):
-    """Finds external contours in a binary mask, returns the outline image and the list of contours."""
-    contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-    outline_image = np.zeros_like(binary_image)
-    cv2.drawContours(outline_image, contours, -1, (255), thickness=1)
-    return cv2.bitwise_not(outline_image), contours
-
-
-def shrink_bbox(image: np.ndarray, shrink_factor: float):
-    """Shrinks the bounding box around the image by a certain factor."""
-    height, width = image.shape[:2]
-    center_x, center_y = width // 2, height // 2
-    new_width = int(width * shrink_factor)
-    new_height = int(height * shrink_factor)
-    x1 = max(center_x - new_width // 2, 0)
-    y1 = max(center_y - new_height // 2, 0)
-    x2 = min(center_x + new_width // 2, width)
-    y2 = min(center_y + new_height // 2, height)
-    return image[y1:y2, x1:x2]
-
-
-def detect_reference_square(img) -> np.ndarray:
-    """Detects the reference square in the image using a YOLO model saved in './last.pt'."""
-    box_detector = YOLO("./last.pt")
-    res = box_detector.predict(img, conf=0.05)
-    del box_detector
-    if len(res) == 0 or len(res[0].boxes) == 0:
-        raise ValueError("No reference square found.")
-    cropped_img = save_one_box(res[0].cpu().boxes.xyxy, res[0].orig_img, save=False)
-    coords = res[0].cpu().boxes.xyxy[0]
-    return cropped_img, coords
-
-
-def build_tool_polygon(points_inch):
-    return Polygon(points_inch)
-
-
-def polygon_to_exterior_coords(poly: Polygon):
-    """Gets the exterior coordinates of a polygon (or the largest piece if MultiPolygon)."""
-    if poly.geom_type == "MultiPolygon":
-        poly = max(poly.geoms, key=lambda g: g.area)
-    if not poly.exterior:
-        return []
-    return list(poly.exterior.coords)
-
-
-def save_dxf_spline(inflated_contours, scaling_factor, height):
-    """Creates a DXF with splines from the inflated contours."""
-    doc = ezdxf.new(units=0)
-    doc.units = ezdxf.units.IN
-    doc.header["$INSUNITS"] = ezdxf.units.IN
-    msp = doc.modelspace()
-    final_polygons_inch = []
-    for contour in inflated_contours:
-        try:
-            resampled = resample_contour(contour)
-            points_inch = [(x * scaling_factor, (height - y) * scaling_factor) for (x, y) in resampled]
-            if len(points_inch) < 3:
-                continue
-            if np.linalg.norm(np.array(points_inch[0]) - np.array(points_inch[-1])) > 1e-6:
-                points_inch.append(points_inch[0])
-            tool_polygon = build_tool_polygon(points_inch)
-            exterior_coords = polygon_to_exterior_coords(tool_polygon)
-            if len(exterior_coords) < 3:
-                continue
-            msp.add_spline(exterior_coords, degree=3, dxfattribs={"layer": "TOOLS"})
-            final_polygons_inch.append(tool_polygon)
-        except ValueError as e:
-            print(f"Skipping contour: {e}")
-    return doc, final_polygons_inch
-
-
-def draw_polygons_inch(polygons_inch, image_rgb, scaling_factor, image_height,
-                       color=(0, 255, 0), thickness=1):
-    """Draws polygons on an image for visualization."""
-    for poly in polygons_inch:
-        if poly.geom_type == "MultiPolygon":
-            for subpoly in poly.geoms:
-                draw_single_polygon(subpoly, image_rgb, scaling_factor, image_height, color, thickness)
-        else:
-            draw_single_polygon(poly, image_rgb, scaling_factor, image_height, color, thickness)
-
-
-def draw_single_polygon(poly, image_rgb, scaling_factor, image_height,
-                        color=(0, 255, 0), thickness=1):
-    """Helper to draw a single polygon."""
-    ext = list(poly.exterior.coords)
-    if len(ext) < 3:
-        return
-    pts_px = []
-    for (x_in, y_in) in ext:
-        px = int(x_in / scaling_factor)
-        py = int(image_height - (y_in / scaling_factor))
-        pts_px.append([px, py])
-    pts_px = np.array(pts_px, dtype=np.int32)
-    cv2.polylines(image_rgb, [pts_px], isClosed=True, color=color, thickness=thickness, lineType=cv2.LINE_AA)
-
-###############################################################################
-# 2) Single-Image Predict (Only Image & Offset)
-###############################################################################
-def predict(image, offset, offset_unit):
-    # Convert offset to inches if necessary
-    if offset_unit == "mm":
-        offset_inches = offset / 25.4
-    else:
-        offset_inches = offset
-
-    try:
-        drawer_img = yolo_detect(image, ["box"])
-        shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
-    except Exception as e:
-        raise gr.Error("Unable to DETECT DRAWER. Please try a different image or angle!") from e
-
-    try:
-        reference_obj_img, scaling_box_coords = detect_reference_square(shrunked_img)
-    except Exception as e:
-        raise gr.Error("Unable to DETECT REFERENCE BOX. Please try a different image!") from e
-
-    reference_obj_img = make_square(reference_obj_img)
-    reference_square_mask = remove_bg(reference_obj_img)
-    reference_square_mask = resize_img(reference_square_mask, (reference_obj_img.shape[1], reference_obj_img.shape[0]))
-
-    try:
-        scaling_factor = calculate_scaling_factor(
-            reference_image_path="./Reference_ScalingBox.jpg",
-            target_image=reference_square_mask,
-            feature_detector="ORB",
-        )
-    except ZeroDivisionError:
-        scaling_factor = None
-        print("Error calculating scaling factor: Division by zero")
-    except Exception as e:
-        scaling_factor = None
-        print(f"Error calculating scaling factor: {e}")
-
-    if scaling_factor is None or scaling_factor == 0:
-        scaling_factor = 1.0
-        print("Using default scaling factor of 1.0 due to calculation error")
-
-    orig_size = shrunked_img.shape[:2]
-    objects_mask = remove_bg(shrunked_img)
-    processed_size = objects_mask.shape[:2]
-
-    # Exclude the reference square from the mask
-    objects_mask = exclude_scaling_box(objects_mask, scaling_box_coords, orig_size, processed_size, expansion_factor=1.2)
-    objects_mask = resize_img(objects_mask, (shrunked_img.shape[1], shrunked_img.shape[0]))
-
-    if scaling_factor != 0:
-        offset_pixels = (offset_inches / scaling_factor) * 2 + 1
-    else:
-        offset_pixels = 1
-
-    dilated_mask = cv2.dilate(objects_mask, np.ones((int(offset_pixels), int(offset_pixels)), np.uint8))
-    outlines, contours = extract_outlines(dilated_mask)
-
-    color_output = cv2.cvtColor(shrunked_img, cv2.COLOR_BGR2RGB)
-    outlines_bgr = cv2.cvtColor(outlines, cv2.COLOR_GRAY2BGR)
-
-    image_height, image_width = shrunked_img.shape[:2]
-    doc, final_polygons_inch = save_dxf_spline(inflated_contours=contours, scaling_factor=scaling_factor, height=image_height)
-
-    # Draw tool outlines on images
-    draw_polygons_inch(final_polygons_inch, color_output, scaling_factor, image_height, color=(0, 255, 0), thickness=1)
-    draw_polygons_inch(final_polygons_inch, outlines_bgr, scaling_factor, image_height, color=(0, 255, 0), thickness=1)
-
-    outlines_color = cv2.cvtColor(outlines_bgr, cv2.COLOR_BGR2RGB)
-
-    # Save DXF file
-    dxf_filepath = os.path.join("./outputs", "out.dxf")
-    doc.saveas(dxf_filepath)
-
-    return color_output, outlines_color, dxf_filepath, dilated_mask, str(scaling_factor)
-
-###############################################################################
-# 3) Batch Processing (Up to 4 Images; Retry Faulty Ones Separately)
-###############################################################################
-def batch_predict(images, offsets_str, offset_unit):
-    offsets = [float(x.strip()) for x in offsets_str.split(",")]
-    if len(images) != len(offsets):
-        raise gr.Error("The number of images and offsets must match!")
-
-    final_images = []
-    outline_images = []
-    mask_images = []
-    scale_factors_dict = {}
-    dxf_files = {}
 
-    error_indices = []
-    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-    zip_path = f"./outputs/batch_{now_str}.zip"
-    zipf = zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED)
 
-    for i, img_path in enumerate(images):
-        try:
-            img_pil = Image.open(img_path).convert("RGB")
-            img_np = np.array(img_pil)
-            offset = offsets[i]
-            color_output, outlines_img, dxf_path, mask_img, sfactor = predict(img_np, offset, offset_unit)
-            final_images.append(Image.fromarray(color_output))
-            outline_images.append(Image.fromarray(outlines_img))
-            mask_images.append(Image.fromarray(mask_img))
-            scale_factors_dict[str(i)] = sfactor
-            base_name = os.path.splitext(os.path.basename(img_path))[0]
-            unique_dxf = f"./outputs/{base_name}_{i}.dxf"
-            os.rename(dxf_path, unique_dxf)
-            dxf_files[i] = unique_dxf
-            zipf.write(unique_dxf, arcname=os.path.basename(unique_dxf))
-        except Exception as e:
-            error_indices.append(i)
-            final_images.append(None)
-            outline_images.append(None)
-            mask_images.append(None)
-            scale_factors_dict[str(i)] = f"Error: {str(e)}"
-    zipf.close()
-    return final_images, outline_images, zip_path, mask_images, scale_factors_dict, error_indices
 
 
-def retry_predict(index, image_path, offset, offset_unit, current_zip_path, current_scale_factors):
     """
-    Retry processing a single faulty image. Returns updated outputs for that image and updated zip & scale factors.
     """
-    try:
-        img_pil = Image.open(image_path).convert("RGB")
-        img_np = np.array(img_pil)
-        color_output, outlines_img, dxf_path, mask_img, sfactor = predict(img_np, offset, offset_unit)
-        processed_img = Image.fromarray(color_output)
-        outline_img = Image.fromarray(outlines_img)
-        mask_image = Image.fromarray(mask_img)
-        base_name = os.path.splitext(os.path.basename(image_path))[0]
-        unique_dxf = f"./outputs/{base_name}_{index}.dxf"
-        os.rename(dxf_path, unique_dxf)
-        # Append the new DXF to the existing zip archive.
-        with zipfile.ZipFile(current_zip_path, "a", zipfile.ZIP_DEFLATED) as zipf:
-            zipf.write(unique_dxf, arcname=os.path.basename(unique_dxf))
-        current_scale_factors[str(index)] = sfactor
-        return processed_img, outline_img, mask_image, unique_dxf, current_zip_path, current_scale_factors, ""
-    except Exception as e:
-        return None, None, None, None, current_zip_path, current_scale_factors, str(e)
-
-###############################################################################
-# 4) Gradio UI
-###############################################################################
-if __name__ == "__main__":
-    os.makedirs("./outputs", exist_ok=True)
-
-    with gr.Blocks() as demo:
-        gr.Markdown("## Choose Processing Mode")
-
-        # Radio to pick Single or Batch
-        mode_select = gr.Radio(choices=["Single", "Batch"], value="Single", label="Select Mode")
-        single_section = gr.Group(visible=True)
-        batch_section = gr.Group(visible=False)
-        retry_section = gr.Group(visible=False)
-
-        # Toggle mode visibility
-        def toggle_mode(mode):
-            if mode == "Single":
-                return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
-            else:
-                return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
-        mode_select.change(fn=toggle_mode, inputs=mode_select, outputs=[single_section, batch_section, retry_section])
-
-        #######################################################################
-        # Single-Image Section
-        #######################################################################
-        with single_section:
-            gr.Markdown("### Single-Image Processing")
-            with gr.Row():
-                with gr.Column():
-                    image_input = gr.Image(label="Input Image")
-                    offset_input = gr.Number(label="Offset", value=0.075)
-                    offset_unit_input = gr.Dropdown(label="Offset Unit", choices=["inches", "mm"], value="inches")
-                    submit_btn = gr.Button("Submit Single")
-                    clear_btn = gr.Button("Clear Single")
-                with gr.Column():
-                    output_image = gr.Image(label="Output Image")
-                    outlines_image = gr.Image(label="Outlined Image")
-                    dxf_file = gr.File(label="DXF File")
-                    mask_image = gr.Image(label="Mask")
-                    scaling_factor_txt = gr.Textbox(label="Scaling Factor (inches/pixel)", placeholder="Computed value")
-            submit_btn.click(fn=predict,
-                             inputs=[image_input, offset_input, offset_unit_input],
-                             outputs=[output_image, outlines_image, dxf_file, mask_image, scaling_factor_txt])
-            clear_btn.click(fn=lambda: (None, None, None, None, ""),
-                            inputs=[], outputs=[output_image, outlines_image, dxf_file, mask_image, scaling_factor_txt])
-
-        #######################################################################
-        # Batch Section
-        #######################################################################
-        # Helper function to limit files to a maximum of 4
-        def limit_files(file_list):
-            """If more than 4 files are uploaded, return only the first 4."""
-            if file_list is None:
-                return None
-            if len(file_list) > 4:
-                return file_list[:4]
-            return file_list
-
-        with batch_section:
-            gr.Markdown("### Batch Processing (Up to 4 Images)")
-            with gr.Row():
-                with gr.Column():
-                    images_input = gr.File(label="Upload 4 Images (up to 4)", file_count="multiple", type="filepath")
-                    images_input.change(fn=limit_files, inputs=images_input, outputs=images_input)
-                    offsets_input = gr.Textbox(label="Offsets (comma-separated, one per image)", placeholder="e.g. 0.1, 0.1")
-                    offset_unit_batch = gr.Dropdown(label="Offset Unit", choices=["inches", "mm"], value="inches")
-                    batch_submit_btn = gr.Button("Submit Batch")
-                    batch_clear_btn = gr.Button("Clear Batch")
-                with gr.Column():
-                    final_images_gallery = gr.Gallery(label="Final Annotated Images", columns=2)
-                    outlines_gallery = gr.Gallery(label="Outlined Images", columns=2)
-                    masks_gallery = gr.Gallery(label="Mask Images", columns=2)
-                    dxf_zip_file = gr.File(label="DXF Files (zip)")
-                    scale_factors_text = gr.JSON(label="Scale Factors (Key=Image Index)")
-                    error_indices_txt = gr.Textbox(label="Error Indices (if any)", interactive=False)
-            batch_submit_btn.click(fn=batch_predict,
-                                   inputs=[images_input, offsets_input, offset_unit_batch],
-                                   outputs=[final_images_gallery, outlines_gallery, dxf_zip_file, masks_gallery, scale_factors_text, error_indices_txt])
-            batch_clear_btn.click(fn=lambda: ([], [], None, [], {}, ""),
-                                  inputs=[], outputs=[final_images_gallery, outlines_gallery, dxf_zip_file, masks_gallery, scale_factors_text, error_indices_txt])
 
-        #######################################################################
-        # Retry Faulty Image Section
-        #######################################################################
-        with retry_section:
-            gr.Markdown("### Retry Faulty Image")
-            with gr.Row():
-                with gr.Column():
-                    retry_index = gr.Textbox(label="Index of Faulty Image (0-indexed)", placeholder="Enter index of failed image")
-                    retry_image_input = gr.Image(label="Replacement Image")
-                    retry_offset = gr.Number(label="Offset", value=0.075)
-                    retry_offset_unit = gr.Dropdown(label="Offset Unit", choices=["inches", "mm"], value="inches")
-                    current_zip = gr.Textbox(label="Current ZIP File Path", interactive=False)
-                    current_scale = gr.JSON(label="Current Scale Factors", value={})
-                    retry_btn = gr.Button("Retry Faulty Image")
-                with gr.Column():
-                    retry_final_img = gr.Image(label="Updated Final Image")
-                    retry_outline_img = gr.Image(label="Updated Outline Image")
-                    retry_mask_img = gr.Image(label="Updated Mask Image")
-                    updated_zip = gr.File(label="Updated ZIP File")
-                    updated_scale = gr.JSON(label="Updated Scale Factors")
-                    retry_error = gr.Textbox(label="Retry Error Message", interactive=False)
-            retry_btn.click(fn=retry_predict,
-                            inputs=[retry_index, retry_image_input, retry_offset, retry_offset_unit, current_zip, current_scale],
-                            outputs=[retry_final_img, retry_outline_img, retry_mask_img, current_zip, updated_zip, updated_scale, retry_error])
-    demo.launch(share=True)
 import torch
+import base64
+import numpy as np
+from io import BytesIO
+from PIL import Image, ImageEnhance
 import gradio as gr
 
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+from prompts import front, back  # prompts.py should define front and back as multiline strings
 
+# Load the OCR model and processor once
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
+    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
+).eval().to(device)
+ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
+# Load the YOLO model (using torch.hub and your custom checkpoint "best.pt")
+yolo_model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt', force_reload=False)
 
+def process_image(input_image):
     """
+    1. Preprocess the input image.
+    2. Run YOLO detection to get the document type and bounding box.
+    3. Crop the image according to the bounding box.
+    4. Based on the detection label ("front" or "back"), select the corresponding prompt.
+    5. Convert the cropped image to base64 and build the chat message.
+    6. Run the OCR model using the constructed prompt and cropped image.
+    7. Return the cropped image and extracted text.
     """
+    # Step 1: Enhance the image (sharpness, contrast, brightness)
+    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
+    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
+    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
+
+    # Step 2: Run YOLO detection
+    # Convert PIL image to numpy array (RGB)
+    image_np = np.array(enhanced_image)
+    results = yolo_model(image_np)
+    df = results.pandas().xyxy[0]
+
+    if df.empty:
+        return enhanced_image, "No document detected by YOLO."
+
+    # Use the detection with the highest confidence
+    best_row = df.sort_values(by="confidence", ascending=False).iloc[0]
+    label = best_row['name']
+    bbox = (int(best_row['xmin']), int(best_row['ymin']),
+            int(best_row['xmax']), int(best_row['ymax']))
+
+    # Step 3: Crop the image using the bounding box
+    cropped_image = enhanced_image.crop(bbox)
+
+    # Step 4: Select the prompt based on the YOLO label
+    if label.lower() == "front":
+        doc_prompt = front
+    elif label.lower() == "back":
+        doc_prompt = back
+    else:
+        doc_prompt = front  # default to front if the label is unexpected
+
+    # Step 5: Convert the cropped image to base64 for the message
+    buffered = BytesIO()
+    cropped_image.save(buffered, format="PNG")
+    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+    # Build the message in the expected format for the OCR processor
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": doc_prompt},
+                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
+            ],
+        }
+    ]
+    text_prompt = ocr_processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    # Step 6: Prepare inputs and run the OCR model
+    inputs = ocr_processor(
+        text=[text_prompt],
+        images=[cropped_image],
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    output = ocr_model.generate(
+        **inputs,
+        temperature=0.8,
+        max_new_tokens=50,
+        num_return_sequences=1,
+        do_sample=True,
+    )
+    prompt_length = inputs["input_ids"].shape[1]
+    new_tokens = output[:, prompt_length:]
+    text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
+    extracted_text = text_output[0]
+
+    # Step 7: Return the cropped (preprocessed) image and extracted text
+    return cropped_image, extracted_text
+
+# Define the Gradio Interface
+iface = gr.Interface(
+    fn=process_image,
+    inputs=gr.Image(type="pil", label="Input Document Image"),
+    outputs=[
+        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
+        gr.Textbox(label="Extracted Text")
+    ],
+    title="Document OCR with YOLO and olmOCR",
+    description=(
+        "Upload an image of a document. The app enhances the image, uses a YOLO model "
+        "to detect and crop the document (front/back), and then extracts text using the OCR model "
+        "with a corresponding prompt."
+    ),
+)
 
+iface.launch()
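
Note: the new app.py imports front and back from prompts.py, which is not included in this commit. A minimal sketch of what that module is expected to contain, per the comment in the diff (the prompt wording below is an assumption for illustration, not the author's actual prompts):

# prompts.py -- hypothetical sketch; app.py only requires that the
# names `front` and `back` exist as multiline prompt strings.
front = """You are reading the FRONT side of an identity document.
Extract every printed field and return it as plain text, one field per line."""

back = """You are reading the BACK side of an identity document.
Extract every printed field and return it as plain text, one field per line."""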
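For a quick local check of the pipeline without starting the Gradio server, process_image can also be called directly. A hypothetical smoke test (the image path is a placeholder; best.pt and prompts.py must be present):

# smoke_test.py -- hypothetical usage sketch, not part of the commit
from PIL import Image

img = Image.open("sample_document.jpg").convert("RGB")  # placeholder path
cropped, text = process_image(img)  # runs enhance -> YOLO crop -> OCR
cropped.save("cropped_preview.png")
print(text)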