import torch
import base64
import numpy as np
from io import BytesIO
from PIL import Image, ImageEnhance
import gradio as gr

from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from ultralytics import YOLO
from prompts import front, back  # prompts.py should define front and back as multiline strings
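# For reference, a minimal sketch of what prompts.py could contain. The wording
# below is only an assumption — tailor the instructions to your document type:
#
#   front = """This is the front side of a document. Extract all readable text,
#   preserving the field names and values you can see."""
#
#   back = """This is the back side of a document. Extract all readable text,
#   preserving the field names and values you can see."""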

# Load the OCR model and processor once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
).eval().to(device)
ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Load the YOLO model using ultralytics (ensure best.pt is in your working directory)
yolo_model = YOLO("best.pt")
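
# Optional sanity check (assumes best.pt sits next to this script; adjust the
# path for your own setup):
# from pathlib import Path
# if not Path("best.pt").exists():
#     raise FileNotFoundError("best.pt not found — place your trained YOLO weights here.")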

def process_image(input_image):
    """
    1. Preprocess the input image.
    2. Run YOLO detection with a confidence threshold of 0.85.
    3. Crop the image according to the detected bounding box.
    4. Choose the corresponding prompt from prompts.py based on the label.
    5. Convert the cropped image to base64 and build the OCR prompt.
    6. Run the OCR model to extract text.
    7. Return the cropped (preprocessed) image and the extracted text.
    """
    # Step 1: Enhance the image (sharpness, contrast, brightness)
    enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
    enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
    enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
    
    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
    image_np = np.array(enhanced_image)
    results = yolo_model.predict(source=image_np, conf=0.85)
    result = results[0]
    
    # If no boxes detected, return the enhanced image with an error message.
    if len(result.boxes) == 0:
        return enhanced_image, "No document detected by YOLO."
    
    # Step 3: Select the detection with the highest confidence
    boxes = result.boxes
    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
    best_index = int(confidences.argmax())
    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
    xmin, ymin, xmax, ymax = map(int, best_box)
    
    # Retrieve the detected label using the model's names mapping
    class_idx = int(boxes.cls[best_index].item())
    label = yolo_model.names[class_idx]
    
    # Step 4: Crop the image using the bounding box
    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
    
    # OPTIMIZATION: Downscale the crop to reduce OCR processing time.
    # thumbnail() resizes in place and preserves the aspect ratio.
    max_size = (800, 800)
    cropped_image.thumbnail(max_size, Image.LANCZOS)
    
    # Select the corresponding OCR prompt based on the YOLO label
    if label.lower() == "front":
        doc_prompt = front
    elif label.lower() == "back":
        doc_prompt = back
    else:
        doc_prompt = front  # Default to front if unexpected label
    
    # Step 5: Convert cropped image to base64 for the message
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    
    # Build the message in the expected format for the OCR processor
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": doc_prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
            ],
        }
    ]
    text_prompt = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    # Step 6: Prepare inputs and run the OCR model
    inputs = ocr_processor(
        text=[text_prompt],
        images=[cropped_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Generation parameters: pick ONE decoding strategy so transformers does not
    # warn about mixed greedy/sampling arguments.

    # Approach 1: Greedy decoding (fastest, deterministic)
    # output = ocr_model.generate(
    #     **inputs,
    #     max_new_tokens=1024,
    #     num_beams=1,
    #     do_sample=False,
    # )

    # Approach 2: Sampling (currently active). For more conservative output,
    # lower temperature to ~0.2 and add top_p=0.95, top_k=50.
    output = ocr_model.generate(
        **inputs,
        do_sample=True,
        temperature=0.8,
        max_new_tokens=1024,
        num_return_sequences=1,
    )
    
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    extracted_text = text_output[0]
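
    # Depending on how the prompts are written, the model may wrap its answer in
    # JSON (olmOCR's default prompt, for instance, returns a "natural_text" field).
    # An optional, hedged sketch for unwrapping such output — leave it commented
    # out if your prompts ask for plain text:
    # import json
    # try:
    #     parsed = json.loads(extracted_text)
    #     extracted_text = parsed.get("natural_text", extracted_text)
    # except (json.JSONDecodeError, TypeError):
    #     pass  # not JSON; keep the raw string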
    
    # Step 7: Return the cropped (preprocessed) image and extracted text
    return cropped_image, extracted_text

# Define the Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Document Image"),
    outputs=[
        gr.Image(type="pil", label="Cropped & Preprocessed Image"),
        gr.Textbox(label="Extracted Text")
    ],
    title="Document OCR with YOLO and OLMOCR",
    description=(
        "Upload an image of a document. The app enhances the image, then extracts text using an OCR model."
    ),
    allow_flagging="never"  # Disable flagging to simplify UI
)

# Launch the app with a public share link (call iface.queue() before launch if request queuing is needed)
iface.launch(share=True)