Update app.py
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image, ImageEnhance
 import gradio as gr
 
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+from ultralytics import YOLO
 from prompts import front, back  # prompts.py should define front and back as multiline strings
 
 # Load the OCR model and processor once
@@ -15,49 +16,54 @@ ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
 ).eval().to(device)
 ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
-# Load the YOLO model
-yolo_model =
+# Load the YOLO model using ultralytics (ensure best.pt is in your working directory)
+yolo_model = YOLO("best.pt")
 
 def process_image(input_image):
     """
     1. Preprocess the input image.
-    2. Run YOLO detection
-    3. Crop the image according to the bounding box.
-    4.
-    5. Convert the cropped image to base64 and build the
-    6. Run the OCR model
-    7. Return the cropped image and extracted text.
+    2. Run YOLO detection with a confidence threshold of 0.85.
+    3. Crop the image according to the detected bounding box.
+    4. Choose the corresponding prompt from prompts.py based on the label.
+    5. Convert the cropped image to base64 and build the OCR prompt.
+    6. Run the OCR model to extract text.
+    7. Return the cropped (preprocessed) image and the extracted text.
     """
     # Step 1: Enhance the image (sharpness, contrast, brightness)
     enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
     enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
     enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
 
-    # Step 2: Run YOLO detection
-    # Convert PIL image to numpy array (RGB)
+    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
     image_np = np.array(enhanced_image)
-    results = yolo_model(image_np)
+    results = yolo_model.predict(source=image_np, conf=0.85)
+    result = results[0]
 
+    # If no boxes detected, return the enhanced image with an error message.
+    if len(result.boxes) == 0:
         return enhanced_image, "No document detected by YOLO."
 
+    # Step 3: Select the detection with the highest confidence
+    boxes = result.boxes
+    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
+    best_index = int(confidences.argmax())
+    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
+    xmin, ymin, xmax, ymax = map(int, best_box)
 
+    # Retrieve the detected label using the model's names mapping
+    class_idx = int(boxes.cls[best_index].item())
+    label = yolo_model.names[class_idx]
 
-    # Step 4:
+    # Step 4: Crop the image using the bounding box
+    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
+
+    # Select the corresponding OCR prompt based on the YOLO label
     if label.lower() == "front":
         doc_prompt = front
     elif label.lower() == "back":
         doc_prompt = back
     else:
-        doc_prompt = front  #
+        doc_prompt = front  # Default to front if unexpected label
 
     # Step 5: Convert cropped image to base64 for the message
     buffered = BytesIO()
@@ -110,11 +116,11 @@ iface = gr.Interface(
         gr.Image(type="pil", label="Cropped & Preprocessed Image"),
         gr.Textbox(label="Extracted Text")
     ],
-    title="Document OCR with YOLO and OLMOCR",
+    title="Document OCR with YOLO (Ultralytics) and OLMOCR",
     description=(
         "Upload an image of a document. The app enhances the image, uses a YOLO model "
-        "to detect and crop the document (front/back)
-        "with a corresponding prompt."
+        "to detect and crop the document (front/back) with a confidence threshold of 0.85, and "
+        "then extracts text using an OCR model with a corresponding prompt."
    ),
)
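app.py imports front and back from prompts.py, but this commit does not touch that file. A minimal sketch of what prompts.py might contain, assuming only what the inline comment in app.py states (the prompt wording below is a placeholder, not the actual prompts):

# prompts.py -- hypothetical contents; app.py only requires that front and back
# be multiline strings used as OCR instructions.
front = """Extract all text fields from the FRONT side of the document.
Return the values exactly as printed."""

back = """Extract all text fields from the BACK side of the document.
Return the values exactly as printed."""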
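The second hunk ends at buffered = BytesIO(), so the rest of Step 5 is not visible in this diff. A plausible continuation, assuming the common Qwen2-VL chat-message format with a base64 data URI (variable names follow the code above; the PNG format choice is an assumption):

import base64

# Encode the cropped PIL image as base64 (assumed continuation, not shown in the diff)
cropped_image.save(buffered, format="PNG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

# Build a chat-style message pairing the image with the selected prompt
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": f"data:image/png;base64,{img_b64}"},
            {"type": "text", "text": doc_prompt},
        ],
    }
]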
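Step 6 (running the OCR model) also falls between the hunks shown. A minimal sketch using the standard transformers API for Qwen2-VL, assuming the messages and cropped_image built above; the generation budget is an arbitrary choice:

# Sketch of the OCR call (assumed; the actual code is outside the diff hunks)
text = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = ocr_processor(text=[text], images=[cropped_image], return_tensors="pt").to(device)
output_ids = ocr_model.generate(**inputs, max_new_tokens=512)  # token budget is an assumption
# Trim the prompt tokens so only newly generated text is decoded
extracted_text = ocr_processor.batch_decode(
    output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)[0]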