note-to-text / app.py
ZennyKenny's picture
Update app.py
194e156 verified
raw
history blame
2.06 kB
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
# Load the TrOCR checkpoint once, at import time. Both objects are
# module-level globals used by the request handler below; keeping the
# checkpoint id in one place guarantees processor and model always match.
_TROCR_CHECKPOINT = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
def preprocess_image(image):
    """Convert a PIL image into the pixel tensor the TrOCR model expects.

    The work is delegated to the checkpoint's own ``processor`` so that
    resizing and mean/std normalisation follow the checkpoint's
    image-processor config. For TrOCR that config is 384x384 with
    mean = std = 0.5, which yields values in [-1, 1].

    A hand-rolled ``Resize`` + ``ToTensor`` pipeline is not equivalent.
    It stops at ToTensor's [0, 1] range and skips the normalisation the
    model was trained with.

    Args:
        image: a PIL image, in any mode (RGB, L, RGBA, ...).

    Returns:
        A float tensor of shape (1, 3, H, W), ready for ``model.generate``.

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided")
    # Grayscale / RGBA / palette uploads -> the 3-channel RGB the model uses.
    image = image.convert("RGB")
    return processor(images=image, return_tensors="pt").pixel_values
def visualize_image(pixel_values):
    """Display a preprocessed image tensor with matplotlib (debugging aid).

    Args:
        pixel_values: a tensor shaped (1, 3, H, W) or (3, H, W), as produced
            by ``preprocess_image``. Values in [0, 1] are shown as-is.
            Tensors containing negative values are assumed to carry the
            TrOCR mean = std = 0.5 normalisation (range [-1, 1]) and are
            mapped back to [0, 1] before display.
    """
    tensor = pixel_values.detach().cpu()
    if tensor.dim() == 4:
        # Drop only the batch axis. A bare squeeze() would also collapse
        # H or W if either happened to be 1.
        tensor = tensor.squeeze(0)
    # (C, H, W) -> (H, W, C): the channel-last layout imshow expects.
    image = tensor.permute(1, 2, 0).numpy()
    if image.min() < 0:
        # Undo (x - 0.5) / 0.5 so the picture is not clipped towards black.
        image = image * 0.5 + 0.5
    image = image.clip(0.0, 1.0)

    # Draw on a dedicated figure and close it afterwards. Drawing through
    # pyplot's implicit current figure, which is never closed, lets artists
    # accumulate across calls in a long-running server process.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title("Preprocessed Image")
    plt.show()
    plt.close(fig)
def recognize_text(image, *, debug=False):
    """Transcribe handwritten text in ``image`` with the TrOCR model.

    Args:
        image: the PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.
        debug: when True, print intermediate values and display the
            preprocessed image with matplotlib. It is off by default
            because calling ``plt.show()`` inside a request handler can
            block the worker on an interactive backend.

    Returns:
        The decoded text, or a message starting with ``"Error:"`` if no
        image was given or inference failed. Errors are returned rather
        than raised so that they appear in the Gradio output box.
    """
    import traceback

    if image is None:
        return "Error: no image provided. Please upload an image of handwritten text."
    try:
        pixel_values = preprocess_image(image)
        if debug:
            print("Image preprocessed. Pixel values shape:", pixel_values.shape)
            visualize_image(pixel_values)

        # Inference only: no_grad skips building the autograd graph.
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        if debug:
            print("Generated IDs:", generated_ids)

        # Batch of one image -> take the single decoded string.
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        if debug:
            print("Decoded text:", text)
        return text
    except Exception as e:
        # UI boundary: keep the app responsive and show the failure to the
        # user. Log the full traceback, not just the message, for debugging.
        traceback.print_exc()
        return f"Error: {str(e)}"
# Gradio UI: one PIL image in, the recognised text out. The interface is
# built and launched as soon as this module runs.
note = gr.Interface(
    recognize_text,
    gr.Image(type="pil"),
    "text",
    title="Handwritten Note to Digital Text",
    description=(
        "Upload an image of handwritten text, and the AI will convert it "
        "to digital text."
    ),
)
note.launch()