import base64
from io import BytesIO

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageEnhance
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from prompts import front, back  # prompts.py should define front and back as multiline strings

# Load the OCR model and processor once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval().to(device)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Load the YOLO model (using torch.hub and the custom checkpoint "best.pt")
yolo_model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt', force_reload=False)


def process_image(input_image):
    """
    1. Preprocess the input image.
    2. Run YOLO detection to get the document type and bounding box.
    3. Crop the image according to the bounding box.
    4. Based on the detection label ("front" or "back"), select the corresponding prompt.
    5. Convert the cropped image to base64 and build the chat message.
    6. Run the OCR model using the constructed prompt and cropped image.
    7. Return the cropped image and extracted text.
    """
    # Step 1: Enhance the image (sharpness, contrast, brightness)
    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)

    # Step 2: Run YOLO detection
    # Convert PIL image to numpy array (RGB)
    image_np = np.array(enhanced_image)
    results = yolo_model(image_np)
    df = results.pandas().xyxy[0]
    if df.empty:
        return enhanced_image, "No document detected by YOLO."

    # Use the detection with the highest confidence
    best_row = df.sort_values(by="confidence", ascending=False).iloc[0]
    label = best_row['name']
    bbox = (
        int(best_row['xmin']), int(best_row['ymin']),
        int(best_row['xmax']), int(best_row['ymax']),
    )

    # Step 3: Crop the image using the bounding box
    cropped_image = enhanced_image.crop(bbox)

    # Step 4: Select the prompt based on the YOLO label
    if label.lower() == "front":
        doc_prompt = front
    elif label.lower() == "back":
        doc_prompt = back
    else:
        doc_prompt = front  # default to front if the label is unexpected

    # Step 5: Convert the cropped image to base64 for the message
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Build the message in the expected format for the OCR processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": doc_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
            ],
        }
    ]
    text_prompt = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Step 6: Prepare inputs and run the OCR model
    inputs = ocr_processor(
        text=[text_prompt],
        images=[cropped_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    output = ocr_model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=50,
        num_return_sequences=1,
        do_sample=True,
    )
    # Decode only the newly generated tokens (strip the prompt portion)
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    extracted_text = text_output[0]

    # Step 7: Return the cropped (preprocessed) image and extracted text
    return cropped_image, extracted_text


# Define the Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Document Image"),
    outputs=[
        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
        gr.Textbox(label="Extracted Text"),
    ],
    title="Document OCR with YOLO and olmOCR",
    description=(
        "Upload an image of a document. The app enhances the image, uses a YOLO model "
        "to detect and crop the document (front/back), and then extracts text using the OCR model "
        "with a corresponding prompt."
    ),
)

iface.launch()
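
# ---------------------------------------------------------------------------
# Note: the script imports `front` and `back` from a local prompts.py that is
# not shown here. The actual prompt wording is project-specific; as a rough
# placeholder only, prompts.py is assumed to define two multiline strings
# along these lines:
#
#   # prompts.py
#   front = """This is the front side of a document.
#   Extract all visible text fields and return them as plain text."""
#
#   back = """This is the back side of a document.
#   Extract all visible text fields and return them as plain text."""
# ---------------------------------------------------------------------------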