import base64
from io import BytesIO

import gradio as gr
import torch
from PIL import ImageEnhance
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from ultralytics import YOLO

from prompts import front, back  # prompts.py should define front and back as multiline strings

# Load the OCR model and processor once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval().to(device)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Load the YOLO model with ultralytics (ensure best.pt is in your working directory)
yolo_model = YOLO("best.pt")


def process_image(input_image):
    """
    1. Preprocess the input image.
    2. Run YOLO detection with a confidence threshold of 0.85.
    3. Crop the image according to the detected bounding box.
    4. Choose the corresponding prompt from prompts.py based on the label.
    5. Convert the cropped image to base64 and build the OCR prompt.
    6. Run the OCR model to extract text.
    7. Return the cropped (preprocessed) image and the extracted text.
    """
    # Guard against an empty submission from the Gradio UI
    if input_image is None:
        return None, "Please upload an image."

    # Step 1: Enhance the image (sharpness, contrast, brightness)
    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)

    # Step 2: Run YOLO detection with a confidence threshold of 0.85.
    # Pass the PIL image directly: ultralytics assumes raw NumPy input is BGR,
    # so converting to an RGB array first would silently swap the channels.
    results = yolo_model.predict(source=enhanced_image, conf=0.85)
    result = results[0]

    # If no boxes were detected, return the enhanced image with an error message.
    if len(result.boxes) == 0:
        return enhanced_image, "No document detected by YOLO."
    # Step 3: Select the detection with the highest confidence
    boxes = result.boxes
    confidences = boxes.conf.cpu().numpy()  # convert tensor to a NumPy array
    best_index = int(confidences.argmax())
    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
    xmin, ymin, xmax, ymax = map(int, best_box)

    # Retrieve the detected label using the model's names mapping
    class_idx = int(boxes.cls[best_index].item())
    label = yolo_model.names[class_idx]

    # Step 4: Crop the image using the bounding box
    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))

    # Select the corresponding OCR prompt based on the YOLO label
    if label.lower() == "front":
        doc_prompt = front
    elif label.lower() == "back":
        doc_prompt = back
    else:
        doc_prompt = front  # Default to front if the label is unexpected

    # Step 5: Convert the cropped image to base64 for the chat message
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Build the message in the format expected by the OCR processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": doc_prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{cropped_base64}"},
                },
            ],
        }
    ]
    text_prompt = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 6: Prepare inputs and run the OCR model
    inputs = ocr_processor(
        text=[text_prompt],
        images=[cropped_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    output = ocr_model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=50,  # raise this for dense documents; 50 tokens truncates longer text
        num_return_sequences=1,
        do_sample=True,
    )

    # Decode only the newly generated tokens (everything after the prompt)
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = ocr_processor.tokenizer.batch_decode(
        new_tokens, skip_special_tokens=True
    )
    extracted_text = text_output[0]

    # Step 7: Return the cropped (preprocessed) image and the extracted text
    return cropped_image, extracted_text


# Define the Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Document Image"),
    outputs=[
        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
        gr.Textbox(label="Extracted Text"),
    ],
    title="Document OCR with YOLO (Ultralytics) and olmOCR",
    description=(
        "Upload an image of a document. The app enhances the image, uses a YOLO model "
        "to detect and crop the document (front/back) with a confidence threshold of 0.85, "
        "and then extracts text using an OCR model with a matching prompt."
    ),
)

if __name__ == "__main__":
    iface.launch()
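# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the prompts.py module imported above.
# Only the names `front` and `back` are required by this script; the wording
# below is an illustrative assumption, not the actual prompts used here.
#
#   # prompts.py
#   front = """Below is the front side of an identity document.
#   Return every printed field (names, numbers, dates) as plain text."""
#
#   back = """Below is the back side of an identity document.
#   Return the machine-readable zone and any remaining fields as plain text."""
# ---------------------------------------------------------------------------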
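# ---------------------------------------------------------------------------
# Optional smoke test: a sketch of how process_image can be exercised without
# the Gradio UI. The file name "sample_front.png" is a placeholder assumption.
#
#   from PIL import Image
#   cropped, text = process_image(Image.open("sample_front.png").convert("RGB"))
#   cropped.save("cropped_output.png")
#   print(text)
# ---------------------------------------------------------------------------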