import torch
import base64
import numpy as np
from io import BytesIO
from PIL import Image, ImageEnhance
import gradio as gr
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from ultralytics import YOLO
from prompts import front, back  # prompts.py should define front and back as multiline strings
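# Illustrative sketch of what prompts.py might contain. The variable names come from
# the import above; the wording of the prompts themselves is an assumption:
#
#   front = """You are given the FRONT side of a document.
#   Extract every visible field (e.g. name, document number, date of birth)
#   and return the result as plain text, one field per line."""
#
#   back = """You are given the BACK side of a document.
#   Extract every visible field (e.g. address, issue date, MRZ lines)
#   and return the result as plain text, one field per line."""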
# Load the OCR model and processor once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval().to(device)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Load the YOLO model using ultralytics (ensure best.pt is in your working directory)
yolo_model = YOLO("best.pt")
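# Assumption: best.pt is a custom-trained detector whose class names include "front"
# and "back" document sides; the label check inside process_image relies on those names.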
def process_image(input_image):
    """
    1. Preprocess the input image.
    2. Run YOLO detection with a confidence threshold of 0.85.
    3. Crop the image according to the detected bounding box.
    4. Choose the corresponding prompt from prompts.py based on the label.
    5. Convert the cropped image to base64 and build the OCR prompt.
    6. Run the OCR model to extract text.
    7. Return the cropped (preprocessed) image and the extracted text.
    """
    # Step 1: Enhance the image (sharpness, contrast, brightness)
    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)

    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
    image_np = np.array(enhanced_image)
    results = yolo_model.predict(source=image_np, conf=0.85)
    result = results[0]

    # If no boxes are detected, return the enhanced image with an error message.
    if len(result.boxes) == 0:
        return enhanced_image, "No document detected by YOLO."

    # Step 3: Select the detection with the highest confidence
    boxes = result.boxes
    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
    best_index = int(confidences.argmax())
    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
    xmin, ymin, xmax, ymax = map(int, best_box)

    # Retrieve the detected label using the model's names mapping
    class_idx = int(boxes.cls[best_index].item())
    label = yolo_model.names[class_idx]

    # Step 4: Crop the image using the bounding box
    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))

    # OPTIMIZATION: Downscale the crop to at most 800x800 (preserving aspect ratio)
    # to reduce OCR processing time.
    max_size = (800, 800)
    cropped_image.thumbnail(max_size, Image.LANCZOS)
    # Select the corresponding OCR prompt based on the YOLO label
    if label.lower() == "front":
        doc_prompt = front
    elif label.lower() == "back":
        doc_prompt = back
    else:
        doc_prompt = front  # Default to front if unexpected label

    # Step 5: Convert cropped image to base64 for the message
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Build the message in the expected format for the OCR processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": doc_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
            ],
        }
    ]
    text_prompt = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 6: Prepare inputs and run the OCR model
    inputs = ocr_processor(
        text=[text_prompt],
        images=[cropped_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Generation: the active call below uses sampling. Two alternative configurations
    # are left commented out for reference; enable at most one generate() call.
    #
    # Alternative 1: greedy decoding (fastest, deterministic)
    # output = ocr_model.generate(
    #     **inputs,
    #     max_new_tokens=40,
    #     num_beams=1,
    #     do_sample=False,  # greedy decoding
    # )
    output = ocr_model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=1024,
        num_return_sequences=1,
        do_sample=True,
    )
    # Alternative 2: lower-temperature nucleus sampling (shorter output budget)
    # output = ocr_model.generate(
    #     **inputs,
    #     max_new_tokens=40,
    #     do_sample=True,
    #     temperature=0.2,
    #     top_p=0.95,
    #     top_k=50,
    #     num_return_sequences=1,
    # )
    # Decode only the newly generated tokens (skip the prompt portion)
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    extracted_text = text_output[0]

    # Step 7: Return the cropped (preprocessed) image and extracted text
    return cropped_image, extracted_text
# Define the Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Document Image"),
    outputs=[
        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
        gr.Textbox(label="Extracted Text"),
    ],
    title="Document OCR with YOLO and OLMOCR",
    description=(
        "Upload an image of a document. The app enhances the image, then extracts text using an OCR model."
    ),
    allow_flagging="never",  # Disable flagging to simplify the UI
)
# Enable the request queue and sharing for the Hugging Face Space
iface.queue()
iface.launch(share=True)
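# The imports above imply a requirements.txt roughly like the following
# (package names only; version pins are left as an assumption):
#
#   torch
#   transformers
#   ultralytics
#   gradio
#   pillow
#   numpy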