Spaces:
Runtime error
Runtime error
FEAT: Tesseract OCR text recog added
Browse files- UI displays text crops and corresponding text for different preprocessed images
- app.py +43 -32
- ocr_libs.py +28 -0
- packages.txt +6 -0
- receipt_images/0.png +0 -0
- receipt_images/2.png +0 -0
- receipt_images/4.png +0 -0
- receipt_images/5.png +0 -0
app.py
CHANGED
@@ -7,6 +7,7 @@ from streamlit_image_select import image_select
|
|
7 |
import torchvision.transforms as transforms
|
8 |
import os
|
9 |
import torch
|
|
|
10 |
|
11 |
def process_image(image):
|
12 |
target_size = (400, 512)
|
@@ -31,23 +32,29 @@ def process_image(image):
|
|
31 |
# image = torch.tensor(image)
|
32 |
return image
|
33 |
|
34 |
-
|
35 |
def load_models():
|
36 |
model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
|
37 |
models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
|
38 |
return models
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def clean_image(image, model):
|
41 |
img_out = model(image.unsqueeze(0))
|
42 |
-
img_out = img_out.reshape(400, 512).detach()
|
43 |
return img_out
|
44 |
|
|
|
45 |
image_folder = "receipt_images"
|
46 |
NUM_IMAGES = 3
|
47 |
|
48 |
image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
|
49 |
-
|
50 |
-
|
51 |
img = None
|
52 |
|
53 |
img_index = image_select(
|
@@ -64,42 +71,46 @@ with st.form("my-form", clear_on_submit=True):
|
|
64 |
if submitted and image_file is not None:
|
65 |
img = Image.open(image_file).convert("L")
|
66 |
|
67 |
-
|
68 |
-
if img is None:
|
69 |
img = Image.open(image_paths[img_index]).convert("L")
|
70 |
|
71 |
|
72 |
cols = st.columns(4)
|
73 |
-
cols[0].text("Input Image")
|
74 |
|
|
|
|
|
75 |
cols[1].text("Full Training")
|
76 |
cols[2].text("8%")
|
77 |
cols[3].text("4%")
|
78 |
-
|
79 |
models = load_models()
|
80 |
|
81 |
if img is not None:
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
7 |
import torchvision.transforms as transforms
|
8 |
import os
|
9 |
import torch
|
10 |
+
from ocr_libs import tess_ocr
|
11 |
|
12 |
def process_image(image):
|
13 |
target_size = (400, 512)
|
|
|
32 |
# image = torch.tensor(image)
|
33 |
return image
|
34 |
|
35 |
+
@st.cache_resource
|
36 |
def load_models():
|
37 |
model_paths = ["models/prep_50.pt", "models/prep_4.pt", "models/prep_4.pt"]
|
38 |
models = [torch.load(mpath, map_location='cpu').eval() for mpath in model_paths]
|
39 |
return models
|
40 |
|
41 |
+
@st.cache_resource
|
42 |
+
def load_ocr():
|
43 |
+
return tess_ocr()
|
44 |
+
|
45 |
+
def get_text_boxes(_ocr, image):
|
46 |
+
return _ocr.detect_text(image)
|
47 |
+
|
48 |
def clean_image(image, model):
|
49 |
img_out = model(image.unsqueeze(0))
|
50 |
+
img_out = transforms.ToPILImage()(img_out.reshape(400, 512).detach())
|
51 |
return img_out
|
52 |
|
53 |
+
ocr = load_ocr()
|
54 |
image_folder = "receipt_images"
|
55 |
NUM_IMAGES = 3
|
56 |
|
57 |
image_paths = [f"{image_folder}/{i}.png" for i in range(NUM_IMAGES)]
|
|
|
|
|
58 |
img = None
|
59 |
|
60 |
img_index = image_select(
|
|
|
71 |
if submitted and image_file is not None:
|
72 |
img = Image.open(image_file).convert("L")
|
73 |
|
74 |
+
# If no image was uploaded, use selected image
|
75 |
+
if img is None and img_index >= 0:
|
76 |
img = Image.open(image_paths[img_index]).convert("L")
|
77 |
|
78 |
|
79 |
cols = st.columns(4)
|
|
|
80 |
|
81 |
+
# Set Text
|
82 |
+
cols[0].text("Input Image")
|
83 |
cols[1].text("Full Training")
|
84 |
cols[2].text("8%")
|
85 |
cols[3].text("4%")
|
|
|
86 |
models = load_models()
|
87 |
|
88 |
if img is not None:
|
89 |
+
with st.spinner('Document Cleaning in progress ...'):
|
90 |
+
img_tensor = process_image(img)
|
91 |
+
pil_image = transforms.ToPILImage()(img_tensor)
|
92 |
+
clned_imgs = [clean_image(torch.clone(img_tensor), m) for m in models]
|
93 |
+
cols[0].image(pil_image)
|
94 |
+
for i in range(3):
|
95 |
+
cols[i + 1].image(clned_imgs[i])
|
96 |
+
|
97 |
+
text_boxes = get_text_boxes(ocr, pil_image)
|
98 |
+
all_texts = list()
|
99 |
+
all_texts.append(ocr.extract_text(pil_image, text_boxes))
|
100 |
+
for i in range(3):
|
101 |
+
all_texts.append(ocr.extract_text(clned_imgs[i], text_boxes))
|
102 |
+
# text_boxes_more = get_text_boxes(ocr, clned_imgs[3])
|
103 |
+
|
104 |
+
print(all_texts)
|
105 |
+
for i, box in enumerate(text_boxes):
|
106 |
+
txt_box_cols = st.columns(5)
|
107 |
+
txt_box_cols[0].image(box[0], use_column_width="always")
|
108 |
+
for j in range(4):
|
109 |
+
txt_box_cols[j + 1].text(all_texts[j][i])
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
ocr_libs.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tesserocr
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class tess_ocr:
|
6 |
+
|
7 |
+
def __init__(self):
|
8 |
+
self.api = tesserocr.PyTessBaseAPI(lang='eng')
|
9 |
+
self.api_line = tesserocr.PyTessBaseAPI(lang='eng', psm=tesserocr.PSM.SINGLE_LINE, oem=tesserocr.OEM.LSTM_ONLY)
|
10 |
+
|
11 |
+
def detect_text(self, image):
|
12 |
+
self.api.SetImage(image)
|
13 |
+
boxes = self.api.GetComponentImages(tesserocr.RIL.WORD, True)
|
14 |
+
return boxes
|
15 |
+
|
16 |
+
def extract_text(self, image, boxes):
|
17 |
+
OFFSET = 6
|
18 |
+
texts = list()
|
19 |
+
for i, (im, box, _, _) in enumerate(boxes):
|
20 |
+
cropped = image.crop((box["x"] - OFFSET, box["y"] - OFFSET , box["x"] + box["w"] + OFFSET, box["y"] + box["h"] + OFFSET))
|
21 |
+
self.api_line.SetImage(cropped)
|
22 |
+
ocrResult = self.api_line.GetUTF8Text().strip()
|
23 |
+
conf = self.api_line.MeanTextConf()
|
24 |
+
texts.append(ocrResult)
|
25 |
+
return texts
|
26 |
+
|
27 |
+
|
28 |
+
|
packages.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
libgl1
|
2 |
+
cmake
|
3 |
+
libssl-dev
|
4 |
+
libtesseract-dev
|
5 |
+
pkg-config
|
6 |
+
tesseract-ocr
|
receipt_images/0.png
CHANGED
![]() |
![]() |
receipt_images/2.png
CHANGED
![]() |
![]() |
receipt_images/4.png
ADDED
![]() |
receipt_images/5.png
ADDED
![]() |