import torch
import base64
import numpy as np
from io import BytesIO
from PIL import Image, ImageEnhance
import gradio as gr
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from ultralytics import YOLO
from prompts import front, back # prompts.py should define front and back as multiline strings
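# For reference, a minimal prompts.py might look roughly like the following
# (the prompt wording here is illustrative only, not the original prompts):
#   front = """You are looking at the FRONT side of a document.
#   Extract every visible field and return it as plain text."""
#   back = """You are looking at the BACK side of a document.
#   Extract every visible field and return it as plain text."""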
# Load the OCR model and processor once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
"allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval().to(device)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# Load the YOLO model using ultralytics (ensure best.pt is in your working directory)
yolo_model = YOLO("best.pt")
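# NOTE: best.pt is assumed to be a custom-trained detector whose class names
# include "front" and "back", since those labels drive the prompt selection below.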
def process_image(input_image):
"""
1. Preprocess the input image.
2. Run YOLO detection with a confidence threshold of 0.85.
3. Crop the image according to the detected bounding box.
4. Choose the corresponding prompt from prompts.py based on the label.
5. Convert the cropped image to base64 and build the OCR prompt.
6. Run the OCR model to extract text.
7. Return the cropped (preprocessed) image and the extracted text.
"""
    # Step 1: Enhance the image (sharpness, contrast, brightness)
    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)

    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
    image_np = np.array(enhanced_image)
    results = yolo_model.predict(source=image_np, conf=0.85)
    result = results[0]

    # If no boxes were detected, return the enhanced image with an error message.
    if len(result.boxes) == 0:
        return enhanced_image, "No document detected by YOLO."

    # Step 3: Select the detection with the highest confidence
    boxes = result.boxes
    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
    best_index = int(confidences.argmax())
    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
    xmin, ymin, xmax, ymax = map(int, best_box)

    # Retrieve the detected label using the model's names mapping
    class_idx = int(boxes.cls[best_index].item())
    label = yolo_model.names[class_idx]

    # Step 4: Crop the image using the bounding box
    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
    # OPTIMIZATION: Downscale the crop to at most 800x800 to reduce OCR processing
    # time; thumbnail() preserves the aspect ratio and never upscales.
    max_size = (800, 800)
    cropped_image.thumbnail(max_size, Image.LANCZOS)
    # Select the corresponding OCR prompt based on the YOLO label
    if label.lower() == "front":
        doc_prompt = front
    elif label.lower() == "back":
        doc_prompt = back
    else:
        doc_prompt = front  # Default to front if unexpected label

    # Step 5: Convert cropped image to base64 for the message
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Build the message in the expected format for the OCR processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": doc_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
            ],
        }
    ]
    text_prompt = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 6: Prepare inputs and run the OCR model
    inputs = ocr_processor(
        text=[text_prompt],
        images=[cropped_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Generation parameters: sampling is active by default. The commented-out
    # alternatives below show a greedy configuration (fastest) and a more
    # conservative sampling configuration; enable exactly one at a time.
    #
    # Alternative 1: Greedy decoding (fastest)
    # output = ocr_model.generate(
    #     **inputs,
    #     max_new_tokens=40,
    #     num_beams=1,
    #     do_sample=False,  # Greedy decoding
    # )
    #
    # Alternative 2: Conservative sampling (lower temperature, top-p/top-k filtering)
    # output = ocr_model.generate(
    #     **inputs,
    #     max_new_tokens=40,
    #     do_sample=True,
    #     temperature=0.2,
    #     top_p=0.95,
    #     top_k=50,
    #     num_return_sequences=1,
    # )
    output = ocr_model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=1024,
        num_return_sequences=1,
        do_sample=True,
    )
    # Strip the prompt tokens so only the newly generated text is decoded
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    extracted_text = text_output[0]

    # Step 7: Return the cropped (preprocessed) image and extracted text
    return cropped_image, extracted_text
# Define the Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Document Image"),
    outputs=[
        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
        gr.Textbox(label="Extracted Text"),
    ],
    title="Document OCR with YOLO and OLMOCR",
    description=(
        "Upload an image of a document. The app enhances the image, detects and crops "
        "the document with YOLO, and extracts its text with the olmOCR model."
    ),
    allow_flagging="never",  # Disable flagging to simplify the UI
)
# Launch the Gradio app (share=True creates a public link when run locally;
# a Hugging Face Space serves the app directly)
iface.launch(share=True)
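# Quick local sanity check (illustrative, not part of the original app): call
# process_image directly instead of going through the Gradio UI. The filename
# below is a placeholder.
# test_image = Image.open("sample_document.jpg").convert("RGB")
# cropped, text = process_image(test_image)
# print(text)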