# olmOCR / app.py
# herokeyboard369 - "Create app.py" (commit b8577b9, verified)
import torch
import base64
import urllib.request
import gradio as gr
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
# Initialize the model
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the olmOCR checkpoint with bfloat16 weights and switch it to
# evaluation mode before it is used for generation.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "allenai/olmOCR-7B-0225-preview",
    torch_dtype=torch.bfloat16,
)
model.eval()

# The processor (tokenizer + image preprocessing) is taken from the
# Qwen2-VL-7B-Instruct repository rather than from the olmOCR checkpoint.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

model.to(device)
# Function to process PDF and generate text
def process_pdf(pdf_file, page_num=1):
    """Run the olmOCR model on one page of an uploaded PDF and return its output.

    The page is rendered to a PNG, combined with the "anchor text" that
    olmocr extracts from the same page, and passed to the model as a single
    chat message. Generation uses sampling (``do_sample=True``), so repeated
    calls on the same page may return different text.

    Args:
        pdf_file: The uploaded PDF, either as a filesystem path string or as
            an object whose ``name`` attribute holds the path (the form
            passed by the Gradio ``File`` input).
        page_num: Page number forwarded to the olmocr helpers. Defaults to
            ``1``, the first page, which is the only page the function
            handled before this parameter existed.

    Returns:
        The decoded text of the newly generated tokens for the first (and
        only) returned sequence.

    Raises:
        ValueError: If no file was supplied or ``page_num`` is less than 1.
    """
    # Fail early with a readable message instead of an AttributeError on
    # ``None.name`` when the user submits without uploading anything.
    if pdf_file is None:
        raise ValueError("No PDF file was provided; please upload a PDF first.")
    if page_num < 1:
        raise ValueError(f"page_num must be >= 1, got {page_num}")

    # Accept both a plain path string and a file-like wrapper exposing ``.name``.
    pdf_filename = pdf_file if isinstance(pdf_file, str) else pdf_file.name

    image_base64 = render_pdf_to_base64png(
        pdf_filename, page_num, target_longest_image_dim=1024
    )
    anchor_text = get_anchor_text(
        pdf_filename, page_num, pdf_engine="pdfreport", target_length=4000
    )
    prompt = build_finetuning_prompt(anchor_text)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                },
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The processor needs the actual image in addition to the templated
    # text, so decode the rendered PNG back into a PIL image.
    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
    inputs = processor(
        text=[text],
        images=[main_image],
        padding=True,
        return_tensors="pt",
    )
    # Move every input tensor to the device the model lives on.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    output = model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=1500,
        num_return_sequences=1,
        do_sample=True,
    )

    # Drop the leading prompt-length tokens so that only the newly
    # generated part of each output sequence is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    text_output = processor.tokenizer.batch_decode(
        new_tokens, skip_special_tokens=True
    )
    return text_output[0]
# Create Gradio Interface
iface = gr.Interface(
    fn=process_pdf,
    # Only offer PDF files in the upload dialog; the handler cannot process
    # anything else.
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=gr.Textbox(label="Extracted Text"),
    title="PDF Text Extractor",
    # The model loaded by this script is allenai/olmOCR-7B-0225-preview
    # (only the processor comes from the Qwen2-VL-7B-Instruct repository),
    # and the handler is invoked for the first page only, so say so.
    description=(
        "Upload a PDF file and extract text from its first page using "
        "olmOCR (allenai/olmOCR-7B-0225-preview)."
    ),
)
# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()