gtata committed · Commit 0755539 · 1 Parent(s): d9a169f

FEAT: Basic VGG inference added

app.py CHANGED
@@ -9,7 +9,26 @@ import os
 import torch
 from ocr_libs import tess_ocr
 
-def process_image(image):
+class PadWhite(object):
+    def __init__(self, size):
+        assert isinstance(size, (int, tuple))
+        if isinstance(size, tuple):
+            self.height, self.width = size
+        elif isinstance(size, int):
+            self.height = self.width = size
+
+    def __call__(self, img):
+        if img.size[0] > self.width or img.size[1] > self.height:
+            img.thumbnail((self.width, self.height))
+        delta_width = self.width - img.size[0]
+        delta_height = self.height - img.size[1]
+        pad_width = delta_width // 2
+        pad_height = delta_height // 2
+        padding = (pad_width, pad_height, delta_width -
+                   pad_width, delta_height - pad_height)
+        return ImageOps.expand(img, padding, fill=255)
+
+def process_image_pos(image):
     target_size = (400, 512)
     # image = Image.open(img_name).convert("L")
     w, h = image.size
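
The new PadWhite transform letterboxes a grayscale crop onto a fixed white canvas: oversized inputs are first shrunk in place with thumbnail(), then the remaining border is split evenly and filled with 255. It calls ImageOps.expand, so PIL's ImageOps must already be imported near the top of app.py (the import sits outside this hunk's context). A minimal sketch of the behaviour, assuming PadWhite as defined in the hunk above:

# Sketch only, not part of the commit: PadWhite((h, w)) pads a smaller
# image up to h x w, centering it on a white background.
from PIL import Image, ImageOps

pad = PadWhite((32, 128))            # size tuple is (height, width)
img = Image.new("L", (100, 20), 0)   # a 100x20 black grayscale test crop
out = pad(img)
assert out.size == (128, 32)         # PIL's .size is (width, height)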
@@ -32,9 +51,19 @@ def process_image(image):
     # image = torch.tensor(image)
     return image
 
+def process_image_vgg(image):
+    input_size = (32, 128)
+    transform = transforms.Compose([
+        PadWhite(input_size),
+        transforms.ToTensor(),
+    ])
+    image = transform(image)
+    return image
+
+
+
 @st.cache_resource
-def load_models():
-    model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
+def load_models(model_paths):
     models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
     return models
 
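
Two details in this hunk are worth flagging. The old hard-coded list pointed at models/prep_4.pt twice, so the column labelled 8% was actually showing the 4% model; passing model_paths in (and using prep_8.pt in the caller below) fixes that. Also, @st.cache_resource caches one entry per distinct argument value, so switching datasets should load each model set once and then reuse it. A sketch of the intended usage, assuming both checkpoint sets exist:

# Sketch only: each distinct model_paths value gets its own cache entry,
# so toggling the radio button below does not re-read files from disk.
pos_models = load_models(["models/prep_50.pt", "models/prep_8.pt", "models/prep_4.pt"])
vgg_models = load_models(["models/vgg_50.pt", "models/vgg_8.pt", "models/vgg_4.pt"])

Note that torch.load here unpickles whole nn.Module objects rather than state dicts, so the defining model classes must be importable wherever the app runs.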
@@ -44,74 +73,91 @@ def load_ocr():
 def get_text_boxes(_ocr, image):
     return _ocr.detect_text(image)
 
-def clean_image(image, model):
+def clean_image(image, model, shape=(400, 512)):
     img_out = model(image.unsqueeze(0))
-    img_out = transforms.ToPILImage()(img_out.reshape(400, 512).detach())
+    img_out = transforms.ToPILImage()(img_out.reshape(*shape).detach())
     return img_out
 
 ocr = load_ocr()
-image_folder = "receipt_images"
-NUM_IMAGES = 3
-
-image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
-img = None
-
-img_index = image_select(
-    label="Select Image",
-    images=image_paths,
-    use_container_width=False,
-    index=-1,
-    return_value="index"
-)
-img = None
-with st.form("my-form", clear_on_submit=True):
-    image_file = st.file_uploader("Upload Image", type=['png', 'jpeg', 'jpg'])
-    submitted = st.form_submit_button("UPLOAD!")
-    if submitted and image_file is not None:
-        img = Image.open(image_file).convert("L")
-
-# If no image was uploaded, use selected image
-if img is None and img_index >= 0:
-    img = Image.open(image_paths[img_index]).convert("L")
-
-
-cols = st.columns(4)
-
-# Set Text
-cols[0].text("Input Image")
-cols[1].text("Full Training")
-cols[2].text("8%")
-cols[3].text("4%")
-models = load_models()
-
-if img is not None:
-    with st.spinner('Document Cleaning in progress ...'):
-        img_tensor = process_image(img)
-        pil_image = transforms.ToPILImage()(img_tensor)
-        clned_imgs = [clean_image(torch.clone(img_tensor), m) for m in models]
-        cols[0].image(pil_image)
-        for i in range(3):
-            cols[i + 1].image(clned_imgs[i])
-
-
-    with st.spinner('Text Detection and Recognition in progress ...'):
-        text_boxes = get_text_boxes(ocr, pil_image)
-        all_texts = list()
-        all_texts.append(ocr.extract_text(pil_image, text_boxes))
-        for i in range(3):
-            all_texts.append(ocr.extract_text(clned_imgs[i], text_boxes))
-        # text_boxes_more = get_text_boxes(ocr, clned_imgs[3])
-        title_cols = st.columns(5)
-        headings = ["Word Image", "Original", "Cleaned (100%)", "Cleaned (8%)", "Cleaned (4%)"]
-        for i, heading in enumerate(headings):
-            title_cols[i].markdown(f"## {heading}")
-
-
-        for i, box in enumerate(text_boxes):
-            txt_box_cols = st.columns(5)
-            txt_box_cols[0].image(box[0], use_column_width="always")
-            for j in range(4):
-                txt_box_cols[j + 1].text(all_texts[j][i])
+
+
+dataset = st.radio(
+    "Choose image type",
+    ('POS', 'VGG'))
+if dataset == "POS":
+    model_paths = ["models/prep_50.pt", "models/prep_8.pt", "models/prep_4.pt"]
+    process_image = process_image_pos
+    image_folder = "receipt_images"
+    shape = (400, 512)
+
+elif dataset == "VGG":
+    model_paths = ["models/vgg_50.pt", "models/vgg_8.pt", "models/vgg_4.pt"]
+    process_image = process_image_vgg
+    image_folder = "text_images"
+    shape = (32, 128)
+
+if dataset:
+    NUM_IMAGES = 3
+
+    image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
+    img = None
+
+    img_index = image_select(
+        label="Select Image",
+        images=image_paths,
+        use_container_width=False,
+        index=-1,
+        return_value="index"
+    )
+    img = None
+    with st.form("my-form", clear_on_submit=True):
+        image_file = st.file_uploader("Upload Image", type=['png', 'jpeg', 'jpg'])
+        submitted = st.form_submit_button("UPLOAD!")
+        if submitted and image_file is not None:
+            img = Image.open(image_file).convert("L")
+
+    # If no image was uploaded, use selected image
+    if img is None and img_index >= 0:
+        img = Image.open(image_paths[img_index]).convert("L")
+
+
+    cols = st.columns(4)
+
+    # Set Text
+    cols[0].text("Input Image")
+    cols[1].text("Full Training")
+    cols[2].text("8%")
+    cols[3].text("4%")
+    models = load_models(model_paths)
+
+    if img is not None:
+        with st.spinner('Document Cleaning in progress ...'):
+            img_tensor = process_image(img)
+            pil_image = transforms.ToPILImage()(img_tensor)
+            clned_imgs = [clean_image(torch.clone(img_tensor), m, shape) for m in models]
+            cols[0].image(pil_image)
+            for i in range(3):
+                cols[i + 1].image(clned_imgs[i])
+
+
+        with st.spinner('Text Detection and Recognition in progress ...'):
+            text_boxes = get_text_boxes(ocr, pil_image)
+            all_texts = list()
+            all_texts.append(ocr.extract_text(pil_image, text_boxes))
+            for i in range(3):
+                all_texts.append(ocr.extract_text(clned_imgs[i], text_boxes))
+            # text_boxes_more = get_text_boxes(ocr, clned_imgs[3])
+            title_cols = st.columns(5)
+            headings = ["Word Image", "Original", "Cleaned (100%)", "Cleaned (8%)", "Cleaned (4%)"]
+            for i, heading in enumerate(headings):
+                title_cols[i].markdown(f"## {heading}")
+
+
+            for i, box in enumerate(text_boxes):
+                txt_box_cols = st.columns(5)
+                txt_box_cols[0].image(box[0], use_column_width="always")
+                for j in range(4):
+                    txt_box_cols[j + 1].text(all_texts[j][i])
 
 
 
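The radio button drives four parallel settings per dataset: checkpoint paths, preprocessing function, sample-image folder, and the output shape that clean_image uses to un-flatten the model output ((400, 512) for receipts vs. (32, 128) for VGG word crops). A hypothetical compaction, not in the commit and with an invented CONFIGS name, expressing the same dispatch as one table:

# Hypothetical sketch: same dispatch as the if/elif above, as a lookup table.
CONFIGS = {
    "POS": (["models/prep_50.pt", "models/prep_8.pt", "models/prep_4.pt"],
            process_image_pos, "receipt_images", (400, 512)),
    "VGG": (["models/vgg_50.pt", "models/vgg_8.pt", "models/vgg_4.pt"],
            process_image_vgg, "text_images", (32, 128)),
}
model_paths, process_image, image_folder, shape = CONFIGS[dataset]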
 
models/vgg_4.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66973ba641156bc7afc03fe16846538887d5aba999b524a9f4f630a2d5f09e94
+size 31124602
models/vgg_50.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2cc5cd193eb730070620680796f5ac0b767213cbfae161bff485a37061966f0e
+size 31124700
models/vgg_8.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a177ea051d8a2b365e778e4fd8f44b1cfbfa93edc1fe22affa1e7a643b6eed91
+size 31124501
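
The three checkpoints above are stored with Git LFS: the repository tracks only the three-line pointer files shown, while the ~31 MB binaries live on the LFS server, addressed by their sha256. A sketch of what that means in practice, assuming a clone made without git-lfs installed:

# Sketch only: without the git-lfs smudge filter, the .pt path contains the
# pointer text itself, not the binary, and torch.load() on it would fail.
with open("models/vgg_4.pt") as f:
    print(f.read())
# version https://git-lfs.github.com/spec/v1
# oid sha256:66973ba641156bc7afc03fe16846538887d5aba999b524a9f4f630a2d5f09e94
# size 31124602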
text_images/0.png ADDED
text_images/1.png ADDED
text_images/2.png ADDED
text_images/3.png ADDED
text_images/4.png ADDED