ammariii08 committed
Commit 7832c3b · verified · 1 Parent(s): c2fa9d2

Create app.py

Files changed (1)
  app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
import os

# Set up caching for Hugging Face models and force CPU-only execution
os.environ["TRANSFORMERS_CACHE"] = "./.cache"  # Note: newer transformers releases prefer HF_HOME
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"      # Disable GPU usage

import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from ultralytics import YOLO
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as T
from transformers import AutoModel, AutoTokenizer
import gc

# Import prompts from prompts.py
from prompts import front as front_prompt, back as back_prompt

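# NOTE: prompts.py is not included in this commit; only the names `front` and
# `back` are assumed to exist. A hypothetical minimal sketch:
#
#     # prompts.py
#     front = "Extract all fields from the front of this license (name, date of birth, license number)."
#     back = "Extract all fields from the back of this license (address, issue and expiry dates)."
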
# ---------------------------
# HUGGING FACE MODEL SETUP (CPU)
# ---------------------------
path = "OpenGVLab/InternVL2_5-2B"
cache_folder = "./.cache"

# Load the Vision AI model and tokenizer globally.
model = AutoModel.from_pretrained(
    path,
    cache_dir=cache_folder,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).eval().to("cpu")

tokenizer = AutoTokenizer.from_pretrained(
    path,
    cache_dir=cache_folder,
    trust_remote_code=True,
    use_fast=False,
)

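# InternVL's remote code exposes the `chat()` helper used below. A quick sanity
# check that the weights really landed on the CPU (illustrative, not part of
# the original file):
#     assert next(model.parameters()).device.type == "cpu"
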
# ---------------------------
# YOLO MODEL INITIALIZATION
# ---------------------------
model_path = "best.pt"
modelY = YOLO(model_path)
modelY.to('cpu')  # Explicitly move the model to CPU

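# best.pt is assumed to be a custom-trained checkpoint whose class names
# include "front" and "back"; the detection logic below depends on that.
# Illustrative check:
#     print(modelY.names)  # expect something like {0: 'back', 1: 'front'}
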
def preprocessing(image):
    """Apply enhancement filters and resize to a 448-pixel width."""
    image = Image.fromarray(np.array(image))  # Accept PIL images or NumPy arrays
    image = ImageEnhance.Sharpness(image).enhance(2.0)   # Increase sharpness
    image = ImageEnhance.Contrast(image).enhance(1.5)    # Increase contrast
    image = ImageEnhance.Brightness(image).enhance(0.8)  # Reduce brightness

    # 448 matches the input resolution used by the InternVL transform below.
    width = 448
    aspect_ratio = image.height / image.width
    height = int(width * aspect_ratio)
    image = image.resize((width, height))
    return image

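# Example (illustrative): a 1280x800 input keeps its aspect ratio:
#     out = preprocessing(Image.new("RGB", (1280, 800)))
#     print(out.size)  # (448, 280)
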
def imageRotation(image):
    """Rotate the image to landscape if it is taller than it is wide."""
    if image.height > image.width:
        return image.rotate(90, expand=True)
    return image

def detect_document(image):
    """Detect the front/back of the document using YOLO."""
    image_np = np.array(image)
    results = modelY(image_np, conf=0.85, device='cpu')

    detected_classes = set()
    labels = []
    bounding_boxes = []

    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])  # Convert the 0-dim tensor to a plain float
            cls = int(box.cls[0])
            class_name = modelY.names[cls]

            detected_classes.add(class_name)
            label = f"{class_name} {conf:.2f}"
            labels.append(label)
            bounding_boxes.append((x1, y1, x2, y2, class_name, conf))

            cv2.rectangle(image_np, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image_np, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    possible_classes = {"front", "back"}
    missing_classes = possible_classes - detected_classes
    if missing_classes:
        labels.append(f"Missing: {', '.join(missing_classes)}")

    return Image.fromarray(image_np), labels, bounding_boxes

def crop_image(image, bounding_boxes):
    """Crop detected bounding boxes from the image.

    Note: crops are keyed by class name, so if the same class is detected
    twice, the later crop overwrites the earlier one.
    """
    cropped_images = {}
    image_np = np.array(image)
    for (x1, y1, x2, y2, class_name, conf) in bounding_boxes:
        cropped = image_np[y1:y2, x1:x2]
        cropped_images[class_name] = Image.fromarray(cropped)
    return cropped_images

# ---------------------------
# VISION AI API FUNCTIONS
# ---------------------------
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the preprocessing transform: RGB conversion, resize, ImageNet normalization."""
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    return transform

def load_image(image_file):
    """Convert a PIL image into a normalized tensor ready for the model."""
    transform = build_transform(input_size=448)
    pixel_values = transform(image_file).unsqueeze(0)  # Add batch dimension
    return pixel_values

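# Shape check (illustrative): any PIL image becomes a (1, 3, 448, 448) float tensor:
#     t = load_image(Image.new("RGB", (640, 400)))
#     print(t.shape)  # torch.Size([1, 3, 448, 448])
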
def vision_ai_api(image, doc_type):
    """Run the model using a dynamic prompt based on the detected doc type."""
    pixel_values = load_image(image).to(torch.float32).to("cpu")
    generation_config = dict(max_new_tokens=512, do_sample=True)

    if doc_type == "front":
        question = front_prompt
    elif doc_type == "back":
        question = back_prompt
    else:
        question = "Please provide document details."

    print("Before requesting model...")
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    print("After requesting model...", response)

    # Clear memory
    del pixel_values
    gc.collect()  # Force garbage collection
    if torch.cuda.is_available():  # No-op here, since GPU usage is disabled above
        torch.cuda.empty_cache()

    return f'Assistant: {response}'

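# Note: do_sample=True makes extraction non-deterministic; for reproducible
# field extraction, greedy decoding may be preferable. Illustrative alternative:
#     generation_config = dict(max_new_tokens=512, do_sample=False)
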
# ---------------------------
# PREDICTION PIPELINE
# ---------------------------
def predict(image):
    """Pipeline: Preprocess → Detect → Crop → Vision AI API call."""
    processed_image = preprocessing(image)
    rotated_image = imageRotation(processed_image)
    detected_image, labels, bounding_boxes = detect_document(rotated_image)
    cropped_images = crop_image(rotated_image, bounding_boxes)

    front_result, back_result = None, None
    if "front" in cropped_images:
        front_result = vision_ai_api(cropped_images["front"], "front")
    if "back" in cropped_images:
        back_result = vision_ai_api(cropped_images["back"], "back")

    api_results = {"front": front_result, "back": back_result}
    single_image = cropped_images.get("front") or cropped_images.get("back") or detected_image
    return single_image, labels, api_results

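# Local smoke test (illustrative; bypasses Gradio, file name is hypothetical):
#     img = Image.open("sample_license.jpg")
#     out_img, out_labels, out_json = predict(img)
#     print(out_labels, out_json)
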
# ---------------------------
# GRADIO INTERFACE LAUNCH
# ---------------------------
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=["image", "text", "json"],
    title="License Field Detection (Front & Back Card)",
)

iface.launch()
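
Run locally with python app.py; Gradio then serves the interface on its default local address. The commit ships no dependency manifest, so the following requirements.txt is inferred from the imports alone (hypothetical and unpinned; timm and einops are assumed to be needed by InternVL's remote code):

    gradio
    torch
    torchvision
    transformers
    ultralytics
    opencv-python-headless
    numpy
    pillow
    timm
    einops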