ZennyKenny committed
Commit 194e156 · verified · 1 Parent(s): 525e830

Update app.py

Files changed (1):
  1. app.py +23 -11
app.py CHANGED
@@ -2,31 +2,43 @@ import gradio as gr
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 from PIL import Image
 import torch
+from torchvision import transforms
 import matplotlib.pyplot as plt
-import spaces
 
 # Load TrOCR model
 processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
 model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
 
-@spaces.GPU
+def preprocess_image(image):
+    # Convert image to RGB
+    image = image.convert("RGB")
+
+    # Resize and normalize the image to [0, 1]
+    transform = transforms.Compose([
+        transforms.Resize((384, 384)),  # Resize to the expected input size
+        transforms.ToTensor(),          # Convert to tensor and scale to [0, 1]
+    ])
+    pixel_values = transform(image).unsqueeze(0)  # Add batch dimension
+    return pixel_values
+
+def visualize_image(pixel_values):
+    # Convert tensor to numpy array and permute dimensions for visualization
+    image = pixel_values.squeeze().permute(1, 2, 0).numpy()
+    plt.imshow(image)
+    plt.title("Preprocessed Image")
+    plt.show()
+
 def recognize_text(image):
     try:
-        # Convert image to RGB if it's not already
-        image = image.convert("RGB")
-        print("Image converted to RGB.")
-
         # Preprocess the image
-        pixel_values = processor(images=image, return_tensors="pt").pixel_values
+        pixel_values = preprocess_image(image)
         print("Image preprocessed. Pixel values shape:", pixel_values.shape)
 
         # Visualize preprocessed image
-        plt.imshow(pixel_values.squeeze().permute(1, 2, 0))
-        plt.title("Preprocessed Image")
-        plt.show()
+        visualize_image(pixel_values)
 
         # Generate text from the image
-        with torch.no_grad():  # Disable gradient calculation for inference
+        with torch.no_grad():
             generated_ids = model.generate(pixel_values)
         print("Generated IDs:", generated_ids)
 
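
For context, a minimal end-to-end sketch of how these pieces are typically wired together; it is a reference sketch, not part of this commit. It uses the processor-based preprocessing from the pre-change version of recognize_text, which also applies the model's pixel normalization (something transforms.ToTensor() alone does not do), and adds the decode step that turns generated_ids back into text. The recognize helper name and the sample image path are illustrative assumptions.

# Reference sketch only -- not part of this commit.
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
model.eval()

def recognize(image_path):
    # Load the image and force three channels, as the app's recognize_text does
    image = Image.open(image_path).convert("RGB")
    # The processor resizes and normalizes pixel values the way TrOCR expects
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():  # inference only, no gradient tracking
        generated_ids = model.generate(pixel_values)
    # Decode token ids back to a string; skip_special_tokens drops BOS/EOS tokens
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Hypothetical usage:
# print(recognize("handwriting_sample.png"))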