FEAT: Basic VGG inference added
- app.py +113 -67
- models/vgg_4.pt +3 -0
- models/vgg_50.pt +3 -0
- models/vgg_8.pt +3 -0
- text_images/0.png +0 -0
- text_images/1.png +0 -0
- text_images/2.png +0 -0
- text_images/3.png +0 -0
- text_images/4.png +0 -0
app.py
CHANGED
@@ -9,7 +9,26 @@ import os
 import torch
 from ocr_libs import tess_ocr
 
-def process_image(image):
+class PadWhite(object):
+    def __init__(self, size):
+        assert isinstance(size, (int, tuple))
+        if isinstance(size, tuple):
+            self.height, self.width = size
+        elif isinstance(size, int):
+            self.height = self.width = size
+
+    def __call__(self, img):
+        if img.size[0] > self.width or img.size[1] > self.height:
+            img.thumbnail((self.width, self.height))
+        delta_width = self.width - img.size[0]
+        delta_height = self.height - img.size[1]
+        pad_width = delta_width // 2
+        pad_height = delta_height // 2
+        padding = (pad_width, pad_height, delta_width -
+                   pad_width, delta_height - pad_height)
+        return ImageOps.expand(img, padding, fill=255)
+
+def process_image_pos(image):
     target_size = (400, 512)
     # image = Image.open(img_name).convert("L")
     w, h = image.size
@@ -32,9 +51,19 @@ def process_image(image):
     # image = torch.tensor(image)
     return image
 
+def process_image_vgg(image):
+    input_size = (32, 128)
+    transform = transforms.Compose([
+        PadWhite(input_size),
+        transforms.ToTensor(),
+    ])
+    image = transform(image)
+    return image
+
+
+
 @st.cache_resource
-def load_models():
-    model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
+def load_models(model_paths):
     models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
     return models
 
@@ -44,74 +73,91 @@ def load_ocr():
 def get_text_boxes(_ocr, image):
     return _ocr.detect_text(image)
 
-def clean_image(image, model):
+def clean_image(image, model, shape=(400, 512)):
     img_out = model(image.unsqueeze(0))
-    img_out = transforms.ToPILImage()(img_out.reshape( …
+    img_out = transforms.ToPILImage()(img_out.reshape(*shape).detach())
     return img_out
 
 ocr = load_ocr()
-… (62 removed lines, old 53-114: remainder of the previous app body; not legible in this page capture)
+
+
+dataset = st.radio(
+    "Choose image type ",
+    ('POS', 'VGG'))
+if dataset == "POS":
+    model_paths = ["models/prep_50.pt", "models/prep_8.pt", "models/prep_4.pt"]
+    process_image = process_image_pos
+    image_folder = "receipt_images"
+    shape = (400, 512)
+
+elif dataset == "VGG":
+    model_paths = ["models/vgg_50.pt", "models/vgg_8.pt", "models/vgg_4.pt"]
+    process_image = process_image_vgg
+    image_folder = "text_images"
+    shape = (32, 128)
+
+if dataset:
+    NUM_IMAGES = 3
+
+    image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
+    img = None
+
+    img_index = image_select(
+        label="Select Image",
+        images=image_paths,
+        use_container_width=False,
+        index=-1,
+        return_value="index"
+    )
+    img = None
+    with st.form("my-form", clear_on_submit=True):
+        image_file = st.file_uploader("Upload Image", type=['png', 'jpeg', 'jpg'])
+        submitted = st.form_submit_button("UPLOAD!")
+        if submitted and image_file is not None:
+            img = Image.open(image_file).convert("L")
+
+    # If no image was uploaded, use selected image
+    if img is None and img_index >= 0:
+        img = Image.open(image_paths[img_index]).convert("L")
+
+    cols = st.columns(4)
+
+    # Set Text
+    cols[0].text("Input Image")
+    cols[1].text("Full Training")
+    cols[2].text("8%")
+    cols[3].text("4%")
+    models = load_models(model_paths)
+
+    if img is not None:
+        with st.spinner('Document Cleaning in progress ...'):
+            img_tensor = process_image(img)
+            pil_image = transforms.ToPILImage()(img_tensor)
+            clned_imgs = [clean_image(torch.clone(img_tensor), m, shape) for m in models]
+            cols[0].image(pil_image)
+            for i in range(3):
+                cols[i + 1].image(clned_imgs[i])
+
+        with st.spinner('Text Detection and Recognition in progress ...'):
+            text_boxes = get_text_boxes(ocr, pil_image)
+            all_texts = list()
+            all_texts.append(ocr.extract_text(pil_image, text_boxes))
+            for i in range(3):
+                all_texts.append(ocr.extract_text(clned_imgs[i], text_boxes))
+            # text_boxes_more = get_text_boxes(ocr, clned_imgs[3])
+            title_cols = st.columns(5)
+            headings = ["Word Image", "Original", "Cleaned (100%)", "Cleaned (8%)", "Cleaned (4%)"]
+            for i, heading in enumerate(headings):
+                title_cols[i].markdown(f"## {heading}")
+
+            for i, box in enumerate(text_boxes):
+                txt_box_cols = st.columns(5)
+                txt_box_cols[0].image(box[0], use_column_width="always")
+                for j in range(4):
+                    txt_box_cols[j + 1].text(all_texts[j][i])
 
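A minimal sketch (not part of the commit) of the tensor shapes the new VGG path assumes, using a dummy identity model in place of the committed checkpoints; `pad_white` is a hypothetical helper mirroring `PadWhite`, and the synthetic input stands in for one of the `text_images` crops:

```python
# Shape walk-through of the VGG path added above, under stated assumptions:
# pad_white is a hypothetical stand-in for app.py's PadWhite (pad with white
# to 32x128, thumbnailing oversized inputs), and nn.Identity replaces the
# real vgg_*.pt cleaning models.
import torch
from PIL import Image, ImageOps
from torchvision import transforms

def pad_white(img, size=(32, 128)):
    h, w = size
    if img.size[0] > w or img.size[1] > h:
        img.thumbnail((w, h))
    dw, dh = w - img.size[0], h - img.size[1]
    return ImageOps.expand(img, (dw // 2, dh // 2, dw - dw // 2, dh - dh // 2), fill=255)

img = Image.new("L", (100, 20), color=0)        # synthetic grayscale word crop
x = transforms.ToTensor()(pad_white(img))       # like process_image_vgg -> (1, 32, 128)
out = torch.nn.Identity()(x.unsqueeze(0))       # app.py feeds the model (1, 1, 32, 128)
pil = transforms.ToPILImage()(out.reshape(32, 128).detach())  # as clean_image does
print(x.shape, out.shape, pil.size)             # (1,32,128) (1,1,32,128) (128, 32)
```

The reshape in `clean_image` only works because each cleaning model returns exactly 32x128 values per image, which is what the new `shape` argument encodes for the VGG branch.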
models/vgg_4.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66973ba641156bc7afc03fe16846538887d5aba999b524a9f4f630a2d5f09e94
+size 31124602
models/vgg_50.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2cc5cd193eb730070620680796f5ac0b767213cbfae161bff485a37061966f0e
+size 31124700
models/vgg_8.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a177ea051d8a2b365e778e4fd8f44b1cfbfa93edc1fe22affa1e7a643b6eed91
+size 31124501
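The three .pt entries above are Git LFS pointer files; the ~31 MB weights themselves live in LFS storage. As a hedged sketch (not from the commit), one checkpoint could be smoke-tested the same way load_models consumes it, assuming the LFS objects have been pulled and that each file holds a full pickled nn.Module, as the `.eval()` call in app.py implies:

```python
# Smoke-test one committed checkpoint, mirroring load_models() in app.py.
# Assumes `git lfs pull` has materialised the real weights and that the file
# stores a pickled module rather than a state_dict (implied by .eval()).
import torch

model = torch.load("models/vgg_8.pt", map_location="cpu").eval()

with torch.no_grad():
    dummy = torch.rand(1, 1, 32, 128)   # one padded 32x128 word image, 1 channel
    out = model(dummy)

print(out.shape)                        # app.py reshapes this to (32, 128)
```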
text_images/0.png
ADDED
text_images/1.png
ADDED
text_images/2.png
ADDED
text_images/3.png
ADDED
text_images/4.png
ADDED