ammariii08 committed
Commit d1385e8 · verified · 1 Parent(s): e12d8b4

Update app.py

Files changed (1):
  1. app.py +31 -25
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image, ImageEnhance
 import gradio as gr
 
 from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+from ultralytics import YOLO
 from prompts import front, back  # prompts.py should define front and back as multiline strings
 
 # Load the OCR model and processor once
@@ -15,49 +16,54 @@ ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
 ).eval().to(device)
 ocr_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
 
-# Load the YOLO model (using torch.hub and your custom checkpoint "best.pt")
-yolo_model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt', force_reload=False)
+# Load the YOLO model using ultralytics (ensure best.pt is in your working directory)
+yolo_model = YOLO("best.pt")
 
 def process_image(input_image):
     """
     1. Preprocess the input image.
-    2. Run YOLO detection to get the document type and bounding box.
-    3. Crop the image according to the bounding box.
-    4. Based on the detection label ("front" or "back"), select the corresponding prompt.
-    5. Convert the cropped image to base64 and build the chat message.
-    6. Run the OCR model using the constructed prompt and cropped image.
-    7. Return the cropped image and extracted text.
+    2. Run YOLO detection with a confidence threshold of 0.85.
+    3. Crop the image according to the detected bounding box.
+    4. Choose the corresponding prompt from prompts.py based on the label.
+    5. Convert the cropped image to base64 and build the OCR prompt.
+    6. Run the OCR model to extract text.
+    7. Return the cropped (preprocessed) image and the extracted text.
     """
     # Step 1: Enhance the image (sharpness, contrast, brightness)
     enhanced_image = ImageEnhance.Sharpness(input_image).enhance(2.0)
     enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(1.5)
     enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(0.8)
 
-    # Step 2: Run YOLO detection
-    # Convert PIL image to numpy array (RGB)
+    # Step 2: Run YOLO detection using ultralytics with confidence threshold = 0.85
     image_np = np.array(enhanced_image)
-    results = yolo_model(image_np)
-    df = results.pandas().xyxy[0]
+    results = yolo_model.predict(source=image_np, conf=0.85)
+    result = results[0]
 
-    if df.empty:
+    # If no boxes detected, return the enhanced image with an error message.
+    if len(result.boxes) == 0:
         return enhanced_image, "No document detected by YOLO."
 
-    # Use the detection with the highest confidence
-    best_row = df.sort_values(by="confidence", ascending=False).iloc[0]
-    label = best_row['name']
-    bbox = (int(best_row['xmin']), int(best_row['ymin']),
-            int(best_row['xmax']), int(best_row['ymax']))
+    # Step 3: Select the detection with the highest confidence
+    boxes = result.boxes
+    confidences = boxes.conf.cpu().numpy()  # convert tensor to numpy array
+    best_index = int(confidences.argmax())
+    best_box = boxes.xyxy[best_index].cpu().numpy().tolist()  # [xmin, ymin, xmax, ymax]
+    xmin, ymin, xmax, ymax = map(int, best_box)
 
-    # Step 3: Crop the image using the bounding box
-    cropped_image = enhanced_image.crop(bbox)
+    # Retrieve the detected label using the model's names mapping
+    class_idx = int(boxes.cls[best_index].item())
+    label = yolo_model.names[class_idx]
 
-    # Step 4: Select the prompt based on YOLO label
+    # Step 4: Crop the image using the bounding box
+    cropped_image = enhanced_image.crop((xmin, ymin, xmax, ymax))
+
+    # Select the corresponding OCR prompt based on the YOLO label
     if label.lower() == "front":
         doc_prompt = front
     elif label.lower() == "back":
         doc_prompt = back
     else:
-        doc_prompt = front  # default to front if label is unexpected
+        doc_prompt = front  # Default to front if unexpected label
 
     # Step 5: Convert cropped image to base64 for the message
     buffered = BytesIO()
@@ -110,11 +116,11 @@ iface = gr.Interface(
         gr.Image(type="pil", label="Cropped & Preprocessed Image"),
         gr.Textbox(label="Extracted Text")
     ],
-    title="Document OCR with YOLO and OLMOCR",
+    title="Document OCR with YOLO (Ultralytics) and OLMOCR",
     description=(
        "Upload an image of a document. The app enhances the image, uses a YOLO model "
-        "to detect and crop the document (front/back), and then extracts text using the OCR model "
-        "with a corresponding prompt."
+        "to detect and crop the document (front/back) with a confidence threshold of 0.85, and "
+        "then extracts text using an OCR model with a corresponding prompt."
     ),
 )
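The heart of this commit is the detection API swap: the torch.hub YOLOv5 loader returned detections as a pandas DataFrame, while the ultralytics package returns a list of Results objects with tensor attributes. A minimal sketch of the mapping, assuming ultralytics is installed and best.pt is the repo's custom checkpoint; the input filename here is a hypothetical placeholder:

from ultralytics import YOLO

# Old API (removed): one pandas DataFrame per image
#   df = yolo_model(image_np).pandas().xyxy[0]
#   columns: xmin, ymin, xmax, ymax, confidence, class, name

# New API (added): a list of Results objects; boxes live in tensors
model = YOLO("best.pt")
result = model.predict(source="card.jpg", conf=0.85)[0]  # "card.jpg" is a placeholder input
for box, conf, cls in zip(result.boxes.xyxy, result.boxes.conf, result.boxes.cls):
    xmin, ymin, xmax, ymax = map(int, box.tolist())
    print(model.names[int(cls)], float(conf), (xmin, ymin, xmax, ymax))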
 
 
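The second hunk ends just as Step 5 begins (buffered = BytesIO()). For context, a sketch of how the base64 encoding and the Qwen2-VL call typically proceed from there, using the standard transformers chat-template API and the names already defined in app.py; this is illustrative, not the repo's exact code, and the max_new_tokens cap is an assumption:

import base64

# Encode the crop (the app embeds this string in its chat message payload).
cropped_image.save(buffered, format="PNG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

# Standard Qwen2-VL usage: pair the image with the selected prompt via the chat template.
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": doc_prompt},
]}]
chat_text = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = ocr_processor(text=[chat_text], images=[cropped_image], return_tensors="pt").to(device)
output_ids = ocr_model.generate(**inputs, max_new_tokens=512)  # assumed generation cap
extracted_text = ocr_processor.batch_decode(
    output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)[0]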