amiguel committed on
Commit
95c6d6f
·
verified ·
1 Parent(s): 0864d4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -130
app.py CHANGED
@@ -1,139 +1,99 @@
1
  import streamlit as st
2
- import torch
 
3
  import base64
4
- from io import BytesIO
5
  from PIL import Image
6
- from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
7
- from olmocr.data.renderpdf import render_pdf_to_base64png
8
- from olmocr.prompts import build_finetuning_prompt
9
- from olmocr.prompts.anchor import get_anchor_text
10
 
11
- # Initialize the model
12
- model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16).eval()
13
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model.to(device)
 
 
 
 
 
 
 
16
 
17
- # Set the font
18
- st.markdown(
19
- """
20
- <style>
21
- @import url('https://fonts.googleapis.com/css2?family=Tw+Cen+MT&display=swap');
22
- body {
23
- font-family: 'Tw Cen MT', sans-serif;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
- </style>
26
- """,
27
- unsafe_allow_html=True,
28
- )
 
 
 
 
 
29
 
30
- # Title and description
31
- st.title("Document Processing App")
32
- st.write("Upload a PDF, Excel, Word, PNG, JPG, or JPEG file to process it.")
33
 
34
- # File uploader
35
- uploaded_file = st.sidebar.file_uploader("Choose a file", type=["pdf", "xls", "xlsx", "doc", "docx", "png", "jpg", "jpeg"])
36
 
37
- if uploaded_file is not None:
38
- # Process the uploaded file
39
- if uploaded_file.type == "application/pdf":
40
- # Render page 1 to an image
41
- image_base64 = render_pdf_to_base64png(uploaded_file, 1, target_longest_image_dim=1024)
42
-
43
- # Build the prompt, using document metadata
44
- anchor_text = get_anchor_text(uploaded_file, 1, pdf_engine="pdfreport", target_length=4000)
45
- prompt = build_finetuning_prompt(anchor_text)
46
-
47
- # Build the full prompt
48
- messages = [
49
- {
50
- "role": "user",
51
- "content": [
52
- {"type": "text", "text": prompt},
53
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
54
- ],
55
- }
56
- ]
57
-
58
- # Apply the chat template and processor
59
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
60
- main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
61
- inputs = processor(
62
- text=[text],
63
- images=[main_image],
64
- padding=True,
65
- return_tensors="pt",
66
- )
67
- inputs = {key: value.to(device) for (key, value) in inputs.items()}
68
-
69
- # Generate the output
70
- output = model.generate(
71
- **inputs,
72
- temperature=0.8,
73
- max_new_tokens=50,
74
- num_return_sequences=1,
75
- do_sample=True,
76
- )
77
-
78
- # Decode the output
79
- prompt_length = inputs["input_ids"].shape[1]
80
- new_tokens = output[:, prompt_length:]
81
- text_output = processor.tokenizer.batch_decode(
82
- new_tokens, skip_special_tokens=True
83
- )
84
-
85
- # Display the result
86
- st.write("Processed Text:")
87
- st.write(text_output)
88
-
89
- elif uploaded_file.type in ["image/png", "image/jpeg"]:
90
- # Load the image
91
- image = Image.open(uploaded_file)
92
- image_base64 = base64.b64encode(image.tobytes()).decode('utf-8')
93
-
94
- # Build the prompt
95
- prompt = "Please describe the content of the image."
96
-
97
- # Build the full prompt
98
- messages = [
99
- {
100
- "role": "user",
101
- "content": [
102
- {"type": "text", "text": prompt},
103
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
104
- ],
105
- }
106
- ]
107
-
108
- # Apply the chat template and processor
109
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
- inputs = processor(
111
- text=[text],
112
- images=[image],
113
- padding=True,
114
- return_tensors="pt",
115
- )
116
- inputs = {key: value.to(device) for (key, value) in inputs.items()}
117
-
118
- # Generate the output
119
- output = model.generate(
120
- **inputs,
121
- temperature=0.8,
122
- max_new_tokens=50,
123
- num_return_sequences=1,
124
- do_sample=True,
125
- )
126
-
127
- # Decode the output
128
- prompt_length = inputs["input_ids"].shape[1]
129
- new_tokens = output[:, prompt_length:]
130
- text_output = processor.tokenizer.batch_decode(
131
- new_tokens, skip_special_tokens=True
132
- )
133
-
134
- # Display the result
135
- st.write("Processed Text:")
136
- st.write(text_output)
137
-
138
- else:
139
- st.write("Unsupported file type.")
 
1
  import streamlit as st
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq
3
+ from pdf2image import convert_from_path
4
  import base64
5
+ import io
6
  from PIL import Image
 
 
 
 
7
 
8
# Load the OCR model and processor from Hugging Face.
_OCR_MODEL_ID = "allenai/olmOCR-7B-0225-preview"


@st.cache_resource
def _load_ocr_model():
    """Load the OCR processor and model once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction. Caching the loaded objects keeps the model from being
    re-instantiated on each rerun. Exceptions are deliberately not handled
    here so that they reach the caller; a failed load returns nothing to
    cache and is attempted again on the next rerun.

    Returns:
        tuple: ``(processor, model)`` as returned by the two
        ``from_pretrained`` calls.
    """
    loaded_processor = AutoProcessor.from_pretrained(_OCR_MODEL_ID)
    loaded_model = AutoModelForVision2Seq.from_pretrained(_OCR_MODEL_ID)
    return loaded_processor, loaded_model


# ``processor`` and ``model`` stay module-level names; both are None when
# loading failed so that downstream code can report the problem to the user.
try:
    processor, model = _load_ocr_model()
except ImportError as e:
    processor = None
    model = None
    print(f"Error loading model: {str(e)}. Please ensure PyTorch is installed.")
except ValueError as e:
    processor = None
    model = None
    print(f"Error with model configuration: {str(e)}")
20
 
21
# JavaScript appended once after all page sections; provides the per-page
# "Copy" buttons and the global "Copy All" button.
_COPY_SCRIPT = '''
<script>
function copyText(id) {
    var text = document.getElementById(id);
    text.select();
    document.execCommand("copy");
}
function copyAll() {
    var texts = document.querySelectorAll("#pages textarea");
    var allText = Array.from(texts).map(t => t.value).join("\\n\\n");
    navigator.clipboard.writeText(allText);
}
</script>
'''


def _pdf_to_images(pdf_file):
    """Convert an uploaded PDF into a list of page images."""
    from pdf2image import convert_from_bytes

    # An in-memory upload (e.g. Streamlit's UploadedFile) only carries the
    # client-side filename in ``name``; it is not a path on disk, so the raw
    # bytes must be handed to pdf2image instead.
    if hasattr(pdf_file, "getvalue"):
        data = pdf_file.getvalue()
    elif hasattr(pdf_file, "read"):
        data = pdf_file.read()
    else:
        # Path-backed object: keep the original behaviour.
        return convert_from_path(pdf_file.name)
    return convert_from_bytes(data)


def _ocr_page(page, max_new_tokens):
    """Run the OCR model on one page image and return only the generated text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Extract the text from this image."},
                {"type": "image"},
            ],
        }
    ]
    # The chat template inserts the image placeholder tokens the processor
    # needs in order to line the pixel data up with the prompt.
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt], images=[page], padding=True, return_tensors="pt"
    )
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # ``generate`` returns prompt + continuation; drop the prompt so the
    # instruction text is not echoed back as "extracted" text.
    prompt_length = inputs["input_ids"].shape[1]
    return processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)


def _render_page_html(number, page, text):
    """Build the HTML section (thumbnail, textarea, copy button) for one page."""
    from html import escape

    # Embed the page image directly in the markup as a base64 PNG data URI.
    buffered = io.BytesIO()
    page.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    img_data = f"data:image/png;base64,{img_str}"

    textarea_id = f"text{number}"
    # The text comes from the model (or an exception message) and may contain
    # markup such as "</textarea>"; escape it so it cannot break the page.
    safe_text = escape(text)
    return f'''
<div class="page" style="margin-bottom: 20px; border-bottom: 1px solid #ccc; padding-bottom: 20px;">
    <h3>Page {number}</h3>
    <div style="display: flex; align-items: flex-start;">
        <img src="{img_data}" alt="Page {number}" style="max-width: 300px; margin-right: 20px;">
        <div style="flex-grow: 1;">
            <textarea id="{textarea_id}" rows="10" style="width: 100%;">{safe_text}</textarea>
            <button onclick="copyText('{textarea_id}')" style="margin-top: 5px;">Copy</button>
        </div>
    </div>
</div>
'''


def process_pdf(pdf_file, *, max_new_tokens=1024):
    """Return an HTML document showing each PDF page with its extracted text.

    Every page is rendered to an image, passed through the module-level OCR
    ``processor``/``model`` pair, and emitted as a section containing the
    page thumbnail, an editable textarea with the text, and a copy button.
    The whole result is returned as a single string (nothing is streamed).

    Args:
        pdf_file: The uploaded PDF. In-memory uploads exposing ``getvalue()``
            or ``read()`` are converted from their bytes; any other object is
            treated as path-backed and opened via its ``name`` attribute.
        max_new_tokens: Upper bound on the number of tokens generated per
            page (keyword-only).

    Returns:
        str: HTML markup. Failures are reported as a ``<p>`` message rather
        than raised: model not loaded, no file given, or PDF conversion
        error. A per-page OCR failure is written into that page's textarea
        and does not abort the remaining pages.
    """
    from html import escape

    if processor is None or model is None:
        return "<p>Error: Model could not be loaded. Check environment setup (PyTorch may be missing) or model compatibility.</p>"

    # Check if a PDF file was uploaded
    if pdf_file is None:
        return "<p>Please upload a PDF file.</p>"

    # Convert PDF to images
    try:
        pages = _pdf_to_images(pdf_file)
    except Exception as e:
        return f"<p>Error converting PDF to images: {escape(str(e))}</p>"

    # Initial HTML with "Copy All" button and container for pages
    parts = [
        '<div><button onclick="copyAll()" style="margin-bottom: 10px;">Copy All</button></div><div id="pages">'
    ]

    for number, page in enumerate(pages, start=1):
        # Best effort per page: one failing page must not lose the others.
        try:
            text = _ocr_page(page, max_new_tokens)
        except Exception as e:
            text = f"Error extracting text: {str(e)}"
        parts.append(_render_page_html(number, page, text))

    # Close the pages container and add the clipboard helpers.
    parts.append('</div>')
    parts.append(_COPY_SCRIPT)
    return "".join(parts)
89
 
90
# Define the Streamlit interface
st.title("PDF Text Extractor")
st.markdown("Upload a PDF file and click 'Extract Text' to see each page's image and extracted text incrementally.")

pdf_input = st.file_uploader("Upload PDF", type=["pdf"])
submit_btn = st.button("Extract Text")

if submit_btn:
    if pdf_input is None:
        # Give explicit feedback instead of silently ignoring the click.
        st.warning("Please upload a PDF file.")
    else:
        # OCR runs synchronously and can take a while; show progress.
        with st.spinner("Extracting text..."):
            output_html = process_pdf(pdf_input)
        # The component has a fixed height, so scrolling must be enabled or
        # every page after the first few is clipped.
        st.components.v1.html(output_html, height=800, scrolling=True)