ZennyKenny commited on
Commit
fd11c5a
·
verified ·
1 Parent(s): 817e54c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
4
  import spaces
5
 
6
  # Load TrOCR model
@@ -10,13 +11,26 @@ model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwri
10
  @spaces.GPU
11
  def recognize_text(image):
12
  try:
 
13
  image = image.convert("RGB")
 
 
 
14
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
15
- generated_ids = model.generate(pixel_values)
 
 
 
 
 
 
 
16
  text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
17
-
 
18
  return text
19
  except Exception as e:
 
20
  return f"Error: {str(e)}"
21
 
22
  # Gradio UI
 
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import torch
5
  import spaces
6
 
7
  # Load TrOCR model
 
11
  @spaces.GPU
12
  def recognize_text(image):
13
  try:
14
+ # Convert image to RGB if it's not already
15
  image = image.convert("RGB")
16
+ print("Image converted to RGB.")
17
+
18
+ # Preprocess the image
19
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
20
+ print("Image preprocessed. Pixel values shape:", pixel_values.shape)
21
+
22
+ # Generate text from the image
23
+ with torch.no_grad(): # Disable gradient calculation for inference
24
+ generated_ids = model.generate(pixel_values)
25
+ print("Generated IDs:", generated_ids)
26
+
27
+ # Decode the generated IDs to text
28
  text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
29
+ print("Decoded text:", text)
30
+
31
  return text
32
  except Exception as e:
33
+ print(f"Error: {str(e)}")
34
  return f"Error: {str(e)}"
35
 
36
  # Gradio UI