ammariii08 committed on
Commit 757bd2d · verified · 1 Parent(s): 92b6d26

Update app.py

Files changed (1): app.py +39 -65
app.py CHANGED
@@ -1,5 +1,6 @@
 import torch
 import base64
+import urllib.request
 import numpy as np
 from io import BytesIO
 from PIL import Image, ImageEnhance
@@ -7,7 +8,9 @@ import gradio as gr
 
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 from ultralytics import YOLO
-from prompts import front, back  # prompts.py should define front and back as multiline strings
+from olmocr.data.renderpdf import render_pdf_to_base64png
+from olmocr.prompts import build_finetuning_prompt
+from olmocr.prompts.anchor import get_anchor_text
 
 # Load the OCR model and processor once
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -21,66 +24,60 @@ yolo_model = YOLO("best.pt")
 
 def process_image(input_image):
     """
-    1. Preprocess the input image.
-    2. Run YOLO detection with a confidence threshold of 0.85.
-    3. Crop the image according to the detected bounding box.
-    4. Choose the corresponding prompt from prompts.py based on the label.
-    5. Convert the cropped image to base64 and build the OCR prompt.
-    6. Run the OCR model to extract text.
-    7. Return the cropped (preprocessed) image and the extracted text.
+    Process the input image:
+    1. Enhance the image.
+    2. Detect and crop the document using YOLO (conf 0.85).
+    3. Generate an OCR prompt from a sample PDF.
+    4. Run the OCR model using the prompt and the cropped image.
+    5. Return the cropped image and extracted text.
     """
-    # Step 1: Enhance the image (sharpness, contrast, brightness)
+    # Step 1: Enhance the input image (sharpness, contrast, brightness)
     enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
     enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
     enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
 
-    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
+    # Step 2: Run YOLO detection with confidence threshold = 0.85
    image_np = np.array(enhanced_image)
    results = yolo_model.predict(source=image_np, conf=0.85)
    result = results[0]
 
-    # If no boxes detected, return the enhanced image with an error message.
     if len(result.boxes) == 0:
         return enhanced_image, "No document detected by YOLO."
 
-    # Step 3: Select the detection with the highest confidence
+    # Select the detection with the highest confidence
     boxes = result.boxes
-    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
+    confidences = boxes.conf.cpu().numpy()
     best_index = int(confidences.argmax())
     best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
     xmin, ymin, xmax, ymax = map(int, best_box)
 
-    # Retrieve the detected label using the model's names mapping
-    class_idx = int(boxes.cls[best_index].item())
-    label = yolo_model.names[class_idx]
-
-    # Step 4: Crop the image using the bounding box
+    # Step 3: Crop the image using the bounding box and optionally resize it
     cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
-
-    # OPTIMIZATION: Resize the image to reduce processing time
-    # Calculate aspect ratio to maintain proportions
-    max_size = (640, 640)  # Further reduced from 800x800
+    max_size = (800, 800)  # Resize to reduce processing time
     cropped_image.thumbnail(max_size, Image.LANCZOS)
 
-    # Select the corresponding OCR prompt based on the YOLO label
-    if label.lower() == "front":
-        doc_prompt = front
-    elif label.lower() == "back":
-        doc_prompt = back
-    else:
-        doc_prompt = front  # Default to front if unexpected label
+    # Step 4: Build the OCR prompt using a sample PDF
+    sample_pdf_url = "https://molmo.allenai.org/paper.pdf"
+    sample_pdf_path = "./paper.pdf"
+    urllib.request.urlretrieve(sample_pdf_url, sample_pdf_path)
+
+    # Render page 1 to an image (used only for prompt building)
+    sample_image_base64 = render_pdf_to_base64png(sample_pdf_path, 1, target_longest_image_dim=1024)
+
+    # Extract document metadata and build the prompt
+    anchor_text = get_anchor_text(sample_pdf_path, 1, pdf_engine="pdfreport", target_length=4000)
+    prompt = build_finetuning_prompt(anchor_text)
 
-    # Step 5: Convert cropped image to base64 for the message
+    # Step 5: Build the OCR message using the generated prompt and the cropped image.
     buffered = BytesIO()
     cropped_image.save(buffered, format="PNG")
     cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
-    # Build the message in the expected format for the OCR processor
     messages = [
         {
             "role": "user",
             "content": [
-                {"type": "text", "text": doc_prompt},
+                {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{cropped_base64}"}},
             ],
         }
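
The prompt-construction calls added in this hunk follow the usage shown in the olmocr README. A minimal standalone sketch, assuming the olmocr package is installed and the sample PDF URL is reachable; hoisting the download out of process_image so it runs once at startup is an illustrative variation, since as committed the PDF is re-downloaded on every request:

import urllib.request

from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text

SAMPLE_PDF_URL = "https://molmo.allenai.org/paper.pdf"  # same sample PDF as above
SAMPLE_PDF_PATH = "./paper.pdf"

def build_sample_prompt() -> str:
    # Download the sample PDF, extract page-1 anchor text, and build the prompt.
    urllib.request.urlretrieve(SAMPLE_PDF_URL, SAMPLE_PDF_PATH)
    anchor_text = get_anchor_text(SAMPLE_PDF_PATH, 1, pdf_engine="pdfreport", target_length=4000)
    return build_finetuning_prompt(anchor_text)

PROMPT = build_sample_prompt()  # computed once at import time, reused per request

Also worth noting: sample_image_base64 from render_pdf_to_base64png is never used afterwards (the message embeds cropped_base64), so that render call could be dropped.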
@@ -98,43 +95,20 @@ def process_image(input_image):
     )
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
-    # FIXED: Generation parameters with proper combinations to avoid warnings
-    # Choose one of these two approaches:
-
-    # Approach 1: Greedy decoding (fastest)
-    # output = ocr_model.generate(
-    #     **inputs,
-    #     max_new_tokens=40,
-    #     num_beams=1,
-    #     do_sample=False  # Greedy decoding
-    # )
-
     output = ocr_model.generate(
-        **inputs,
-        temperature=0.2,
-        max_new_tokens=50,
-        num_return_sequences=1,
-        do_sample=True,
-    )
-
-    # Uncomment this block and comment the above if you want sampling instead
-    # # Approach 2: Sampling (more natural but slower)
-    # output = ocr_model.generate(
-    #     **inputs,
-    #     max_new_tokens=40,
-    #     do_sample=True,
-    #     temperature=0.2,
-    #     top_p=0.95,
-    #     top_k=50,
-    #     num_return_sequences=1
-    # )
+        **inputs,
+        temperature=0.8,
+        max_new_tokens=50,
+        num_return_sequences=1,
+        do_sample=True,
+    )
 
     prompt_length = inputs["input_ids"].shape[1]
     new_tokens = output[:, prompt_length:]
     text_output = ocr_processor.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
     extracted_text = text_output[0]
 
-    # Step 7: Return the cropped (preprocessed) image and extracted text
+    # Step 7: Return the cropped image and the extracted text
     return cropped_image, extracted_text
 
 # Define the Gradio Interface
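
On the generation change: the new call keeps do_sample=True and raises temperature from 0.2 to 0.8, so transcriptions are not reproducible, and max_new_tokens=50 will truncate all but very short documents. A sketch of the greedy alternative (essentially the removed "Approach 1" comment block), dropping into the same spot; the 512-token budget is an assumption, not a value from this commit:

# Greedy decoding: deterministic output, no sampling parameters needed.
# Assumes ocr_model and inputs as defined in app.py; 512 is an assumed budget.
output = ocr_model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=False,
)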
@@ -147,10 +121,10 @@ iface = gr.Interface(
     ],
     title="Document OCR with YOLO and OLMOCR",
     description=(
-        "Upload an image of a document. The app enhances the image, then extracts text using an OCR model."
+        "Upload an image of a document. The app enhances the image, detects and crops it using YOLO, "
+        "then builds an OCR prompt from a sample PDF and extracts text."
     ),
-    allow_flagging="never"  # Disable flagging to simplify UI
+    allow_flagging="never"
 )
 
-# Enable queue and sharing for Hugging Face Space
 iface.launch(share=True)
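
The hunks above do not show the processor call between the messages block and inputs (old lines 87-97 are elided). For reference, the olmocr README performs that step roughly as follows; a sketch assuming ocr_processor, cropped_image, and messages as defined in app.py:

# Render the chat template to a string, then tokenize text and image together.
text = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = ocr_processor(
    text=[text],
    images=[cropped_image],
    padding=True,
    return_tensors="pt",
)
inputs = {k: v.to(device) for k, v in inputs.items()}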
 