ammariii08 committed (verified)
Commit 1345fd5
Parent(s): 757bd2d

Update app.py

Files changed (1)
  1. app.py +95 -38
app.py CHANGED
@@ -1,6 +1,5 @@
 import torch
 import base64
-import urllib.request
 import numpy as np
 from io import BytesIO
 from PIL import Image, ImageEnhance
@@ -8,10 +7,12 @@ import gradio as gr
 
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 from ultralytics import YOLO
-from olmocr.data.renderpdf import render_pdf_to_base64png
+from prompts import front, back  # prompts.py should define front and back as multiline strings
+
 from olmocr.prompts import build_finetuning_prompt
 from olmocr.prompts.anchor import get_anchor_text
 
+
 # Load the OCR model and processor once
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -24,55 +25,88 @@ yolo_model = YOLO("best.pt")
 
 def process_image(input_image):
     """
-    Process the input image:
-    1. Enhance the image.
-    2. Detect and crop the document using YOLO (conf ≥ 0.85).
-    3. Generate an OCR prompt from a sample PDF.
-    4. Run the OCR model using the prompt and the cropped image.
-    5. Return the cropped image and extracted text.
+    1. Preprocess the input image.
+    2. Run YOLO detection with a confidence threshold of 0.85.
+    3. Crop the image according to the detected bounding box.
+    4. Choose the corresponding prompt from prompts.py based on the label.
+    5. Convert the cropped image to base64 and build the OCR prompt.
+    6. Run the OCR model to extract text.
+    7. Return the cropped (preprocessed) image and the extracted text.
     """
-    # Step 1: Enhance the input image (sharpness, contrast, brightness)
+    # Step 1: Enhance the image (sharpness, contrast, brightness)
     enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
     enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
     enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
 
-    # Step 2: Run YOLO detection with confidence threshold = 0.85
+    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
     image_np = np.array(enhanced_image)
     results = yolo_model.predict(source=image_np, conf=0.85)
    result = results[0]
 
+    # If no boxes detected, return the enhanced image with an error message.
     if len(result.boxes) == 0:
         return enhanced_image, "No document detected by YOLO."
 
-    # Select the detection with the highest confidence
+    # Step 3: Select the detection with the highest confidence
     boxes = result.boxes
-    confidences = boxes.conf.cpu().numpy()
+    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
     best_index = int(confidences.argmax())
     best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
     xmin, ymin, xmax, ymax = map(int, best_box)
 
-    # Step 3: Crop the image using the bounding box and optionally resize it
-    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
-    max_size = (800, 800)  # Resize to reduce processing time
-    cropped_image.thumbnail(max_size, Image.LANCZOS)
+    # Retrieve the detected label using the model's names mapping
+    class_idx = int(boxes.cls[best_index].item())
+    label = yolo_model.names[class_idx]
 
-    # Step 4: Build the OCR prompt using a sample PDF
-    sample_pdf_url = "https://molmo.allenai.org/paper.pdf"
-    sample_pdf_path = "./paper.pdf"
-    urllib.request.urlretrieve(sample_pdf_url, sample_pdf_path)
+    # Step 4: Crop the image using the bounding box
+    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
 
-    # Render page 1 to an image (used only for prompt building)
-    sample_image_base64 = render_pdf_to_base64png(sample_pdf_path, 1, target_longest_image_dim=1024)
+    # OPTIMIZATION: Resize the image to reduce processing time
+    # Calculate aspect ratio to maintain proportions
+    max_size = (640, 640)  # Further reduced from 800x800
+    cropped_image.thumbnail(max_size, Image.LANCZOS)
 
-    # Extract document metadata and build the prompt
-    anchor_text = get_anchor_text(sample_pdf_path, 1, pdf_engine="pdfreport", target_length=4000)
-    prompt = build_finetuning_prompt(anchor_text)
+    # # Select the corresponding OCR prompt based on the YOLO label
+    # if label.lower() == "front":
+    #     doc_prompt = front
+    # elif label.lower() == "back":
+    #     doc_prompt = back
+    # else:
+    #     doc_prompt = front  # Default to front if unexpected label
 
-    # Step 5: Build the OCR message using the generated prompt and the cropped image.
+    # Step 5: Convert cropped image to base64 for the message
     buffered = BytesIO()
     cropped_image.save(buffered, format="PNG")
     cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
+    # # Build the message in the expected format for the OCR processor
+    # messages = [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "text", "text": doc_prompt},
+    #             {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
+    #         ],
+    #     }
+    # ]
+    # text_prompt = ocr_processor.apply_chat_template(
+    #     messages, tokenize=False, add_generation_prompt=True
+    # )
+
+    # # Step 6: Prepare inputs and run the OCR model
+    # inputs = ocr_processor(
+    #     text=[text_prompt],
+    #     images=[cropped_image],
+    #     padding=True,
+    #     return_tensors="pt",
+    # )
+    # inputs = {k: v.to(device) for k, v in inputs.items()}
+
+
+    anchor_text = extract_anchor_text_from_image(cropped_image)  # You'll need to implement this
+    prompt = build_finetuning_prompt(anchor_text)
+
+    # Build the message in the expected format for the OCR processor
     messages = [
         {
             "role": "user",
@@ -86,7 +120,7 @@ def process_image(input_image):
         messages, tokenize=False, add_generation_prompt=True
     )
 
-    # Step 6: Prepare inputs and run the OCR model
+    # Rest of your code for processing with OCR
    inputs = ocr_processor(
         text=[text_prompt],
         images=[cropped_image],
@@ -95,20 +129,43 @@
     )
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
-    output = ocr_model.generate(
-        **inputs,
-        temperature=0.8,
-        max_new_tokens=50,
-        num_return_sequences=1,
-        do_sample=True,
-    )
+    # FIXED: Generation parameters with proper combinations to avoid warnings
+    # Choose one of these two approaches:
+
+    # Approach 1: Greedy decoding (fastest)
+    # output = ocr_model.generate(
+    #     **inputs,
+    #     max_new_tokens=40,
+    #     num_beams=1,
+    #     do_sample=False  # Greedy decoding
+    # )
+
+    output = model.generate(
+        **inputs,
+        temperature=0.2,
+        max_new_tokens=50,
+        num_return_sequences=1,
+        do_sample=True,
+    )
+
+    # Uncomment this block and comment the above if you want sampling instead
+    # # Approach 2: Sampling (more natural but slower)
+    # output = ocr_model.generate(
+    #     **inputs,
+    #     max_new_tokens=40,
+    #     do_sample=True,
+    #     temperature=0.2,
+    #     top_p=0.95,
+    #     top_k=50,
+    #     num_return_sequences=1
+    # )
 
     prompt_length = inputs["input_ids"].shape[1]
     new_tokens = output[:, prompt_length:]
     text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
     extracted_text = text_output[0]
 
-    # Step 7: Return the cropped image and the extracted text
+    # Step 7: Return the cropped (preprocessed) image and extracted text
     return cropped_image, extracted_text
 
 # Define the Gradio Interface
@@ -121,10 +178,10 @@ iface = gr.Interface(
     ],
     title="Document OCR with YOLO and OLMOCR",
     description=(
-        "Upload an image of a document. The app enhances the image, detects and crops it using YOLO, "
-        "then builds an OCR prompt from a sample PDF and extracts text."
+        "Upload an image of a document. The app enhances the image, then extracts text using an OCR model."
     ),
-    allow_flagging="never"
+    allow_flagging="never"  # Disable flagging to simplify UI
 )
 
+# Enable queue and sharing for Hugging Face Space
 iface.launch(share=True)
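The new import expects a sibling prompts.py defining front and back as multiline strings (per the inline comment), but that file is not part of this commit. A purely illustrative stub, assuming the two strings carry the OCR instructions for the front and back sides of the document:

# prompts.py - illustrative stub only; the real prompt text is not included in this commit.
front = """Extract all readable fields from the FRONT side of the document and return them as plain text."""
back = """Extract all readable fields from the BACK side of the document and return them as plain text."""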
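The active prompt-building path calls extract_anchor_text_from_image, which the inline comment marks as still to be implemented and which is defined nowhere in this diff. A minimal placeholder sketch, assuming a plain pytesseract OCR pass is an acceptable stand-in for the PDF-based get_anchor_text:

import pytesseract
from PIL import Image

def extract_anchor_text_from_image(image: Image.Image, target_length: int = 4000) -> str:
    # Hypothetical helper, not part of this commit: rough anchor text from the cropped image.
    # Requires pytesseract and the Tesseract binary; any other OCR fallback could be substituted.
    text = pytesseract.image_to_string(image)
    return text[:target_length]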
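The uncommented generation call references model, while the checkpoint loaded at the top of app.py is bound to ocr_model. Assuming that is the object intended, the sampling call would read:

# Sketch assuming `model` was meant to be the `ocr_model` loaded above (not confirmed by this commit).
output = ocr_model.generate(
    **inputs,
    temperature=0.2,
    max_new_tokens=50,
    num_return_sequences=1,
    do_sample=True,
)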