gtata commited on
Commit
2ecfea1
·
1 Parent(s): 43e62b9

FEAT: Tesseract OCR text recognition added

Browse files

- UI displays text crops and corresponding text for different preprocessed images

app.py CHANGED
@@ -7,6 +7,7 @@ from streamlit_image_select import image_select
7
  import torchvision.transforms as transforms
8
  import os
9
  import torch
 
10
 
11
  def process_image(image):
12
  target_size = (400, 512)
@@ -31,23 +32,29 @@ def process_image(image):
31
  # image = torch.tensor(image)
32
  return image
33
 
34
-
35
  def load_models():
36
  model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
37
  models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
38
  return models
39
 
 
 
 
 
 
 
 
40
  def clean_image(image, model):
41
  img_out = model(image.unsqueeze(0))
42
- img_out = img_out.reshape(400, 512).detach().numpy()
43
  return img_out
44
 
 
45
  image_folder = "receipt_images"
46
  NUM_IMAGES = 3
47
 
48
  image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
49
-
50
-
51
  img = None
52
 
53
  img_index = image_select(
@@ -64,42 +71,46 @@ with st.form("my-form", clear_on_submit=True):
64
  if submitted and image_file is not None:
65
  img = Image.open(image_file).convert("L")
66
 
67
- print(img_index)
68
- if img is None:
69
  img = Image.open(image_paths[img_index]).convert("L")
70
 
71
 
72
  cols = st.columns(4)
73
- cols[0].text("Input Image")
74
 
 
 
75
  cols[1].text("Full Training")
76
  cols[2].text("8%")
77
  cols[3].text("4%")
78
-
79
  models = load_models()
80
 
81
  if img is not None:
82
-
83
- img_tensor = process_image(img)
84
- clned_imgs = [clean_image(img_tensor, m) for m in models]
85
-
86
-
87
- cols[0].image(img_tensor.permute(1,2, 0).numpy())
88
-
89
-
90
- cols[1].image(clned_imgs[0])
91
-
92
-
93
- cols[2].image(clned_imgs[1])
94
-
95
-
96
- cols[3].image(clned_imgs[2])
97
-
98
-
99
-
100
-
101
-
102
- # cols = st.columns(NUM_IMAGES)
103
- # for i in range(NUM_IMAGES):
104
- # image = Image.open(os.path.join(image_folder, f"{i}.png"))
105
- # cols[i].image(image)
 
 
 
 
 
7
  import torchvision.transforms as transforms
8
  import os
9
  import torch
10
+ from ocr_libs import tess_ocr
11
 
12
  def process_image(image):
13
  target_size = (400, 512)
 
32
  # image = torch.tensor(image)
33
  return image
34
 
35
@st.cache_resource
def load_models():
    """Load the three document-cleaning models once and cache them across
    Streamlit reruns.

    Returns:
        list of eval-mode torch models, ordered to match the UI columns
        "Full Training", "8%" and "4%".

    NOTE(review): the second and third entries are both "models/prep_4.pt";
    the "8%" column presumably should load a prep_8 checkpoint — confirm
    against the models/ directory before changing.
    """
    model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
    models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
    return models
40
 
41
@st.cache_resource
def load_ocr():
    """Construct the Tesseract OCR wrapper once; cached across reruns."""
    engine = tess_ocr()
    return engine
44
+
45
def get_text_boxes(_ocr, image):
    """Detect word-level text boxes in *image* using the OCR engine.

    The leading underscore on _ocr presumably keeps Streamlit caching from
    trying to hash the engine object — confirm if this is ever cached.
    """
    boxes = _ocr.detect_text(image)
    return boxes
47
+
48
def clean_image(image, model):
    """Run a document-cleaning model over a grayscale image tensor.

    Args:
        image: torch.Tensor shaped (1, H, W), as produced by process_image.
        model: callable cleaning network applied to the batched (1, 1, H, W)
            input.

    Returns:
        PIL.Image of the cleaned document at the input's H x W resolution.
    """
    # Add a batch dimension for the model, then flatten back to (H, W).
    img_out = model(image.unsqueeze(0))
    # Derive H/W from the input instead of hard-coding 400x512 so this helper
    # stays in sync with whatever target_size process_image uses.
    img_out = transforms.ToPILImage()(img_out.reshape(image.shape[-2], image.shape[-1]).detach())
    return img_out
52
 
53
+ ocr = load_ocr()
54
  image_folder = "receipt_images"
55
  NUM_IMAGES = 3
56
 
57
  image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
 
 
58
  img = None
59
 
60
  img_index = image_select(
 
71
  if submitted and image_file is not None:
72
  img = Image.open(image_file).convert("L")
73
 
74
+ # If no image was uploaded, use selected image
75
+ if img is None and img_index >= 0:
76
  img = Image.open(image_paths[img_index]).convert("L")
77
 
78
 
79
  cols = st.columns(4)
 
80
 
81
+ # Set Text
82
+ cols[0].text("Input Image")
83
  cols[1].text("Full Training")
84
  cols[2].text("8%")
85
  cols[3].text("4%")
 
86
  models = load_models()
87
 
88
if img is not None:
    with st.spinner('Document Cleaning in progress ...'):
        # Preprocess once; keep a PIL copy for display and for OCR.
        img_tensor = process_image(img)
        pil_image = transforms.ToPILImage()(img_tensor)
        # Clone per model — presumably guards against in-place ops inside the
        # cleaning networks; confirm before removing.
        clned_imgs = [clean_image(torch.clone(img_tensor), m) for m in models]

        # Top row: input image followed by the three cleaned variants.
        cols[0].image(pil_image)
        for i, cleaned in enumerate(clned_imgs):
            cols[i + 1].image(cleaned)

        # Detect word boxes on the input image once, then OCR those same
        # boxes on the input and on each cleaned variant so the per-word
        # rows below compare like with like.
        text_boxes = get_text_boxes(ocr, pil_image)
        all_texts = [ocr.extract_text(pil_image, text_boxes)]
        for cleaned in clned_imgs:
            all_texts.append(ocr.extract_text(cleaned, text_boxes))

        # One row per detected word: the crop plus the four recognitions.
        for i, box in enumerate(text_boxes):
            txt_box_cols = st.columns(5)
            txt_box_cols[0].image(box[0], use_column_width="always")
            for j, texts in enumerate(all_texts):
                txt_box_cols[j + 1].text(texts[i])
ocr_libs.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tesserocr
2
+
3
+
4
+
5
class tess_ocr:
    """Thin wrapper around tesserocr: word-box detection on a full page and
    per-box text re-recognition for receipt images."""

    def __init__(self):
        # Full-page API used for layout analysis / word detection.
        self.api = tesserocr.PyTessBaseAPI(lang='eng')
        # Single-line LSTM API used to re-recognize each cropped word.
        self.api_line = tesserocr.PyTessBaseAPI(lang='eng', psm=tesserocr.PSM.SINGLE_LINE, oem=tesserocr.OEM.LSTM_ONLY)

    def detect_text(self, image):
        """Return word-level component images and boxes for a PIL image."""
        self.api.SetImage(image)
        boxes = self.api.GetComponentImages(tesserocr.RIL.WORD, True)
        return boxes

    def extract_text(self, image, boxes):
        """OCR each box (as returned by detect_text) against *image*.

        Boxes are padded by OFFSET pixels of context; coordinates are clamped
        to the image bounds so crops near the border do not include the
        undefined padding PIL generates for out-of-range crop windows.

        Returns:
            list of recognized strings, one per box (possibly empty strings).
        """
        OFFSET = 6
        width, height = image.size
        texts = []
        for _, box, _, _ in boxes:
            # Clamp the padded window (fixes negative coordinates for boxes
            # touching the top/left edge of the page).
            left = max(0, box["x"] - OFFSET)
            top = max(0, box["y"] - OFFSET)
            right = min(width, box["x"] + box["w"] + OFFSET)
            bottom = min(height, box["y"] + box["h"] + OFFSET)
            cropped = image.crop((left, top, right, bottom))
            self.api_line.SetImage(cropped)
            texts.append(self.api_line.GetUTF8Text().strip())
        return texts
packages.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ libgl1
2
+ cmake
3
+ libssl-dev
4
+ libtesseract-dev
5
+ pkg-config
6
+ tesseract-ocr
receipt_images/0.png CHANGED
receipt_images/2.png CHANGED
receipt_images/4.png ADDED
receipt_images/5.png ADDED