yuki-imajuku committed on
Commit b5e2084 · 0 Parent(s):

initial commit

Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +12 -0
  3. app.py +114 -0
  4. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Manga Panel OCR
+ emoji: 📚
+ colorFrom: gray
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 5.20.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,114 @@
+ # Install FlashAttention
+ import subprocess
+ subprocess.run(
+     "pip install flash-attn --no-build-isolation",
+     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     shell=True,
+ )
+
+ import base64
+ from io import BytesIO
+ import re
+
+ from PIL import Image, ImageDraw
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+
+
+ def pil2base64(image: Image.Image) -> str:
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     return base64.b64encode(buffered.getvalue()).decode()
+
+
+ @spaces.GPU
+ @torch.inference_mode()
+ def inference_fn(
+     image: Image.Image | None,
+     # progress=gr.Progress(track_tqdm=True),
+ ) -> tuple[str, Image.Image | None]:
+     if image is None:
+         gr.Warning("Please upload an image!", duration=10)
+         return "Please upload an image!", None
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         "yuki-imajuku/Qwen2.5-VL-3B-Instruct-FT-Manga109-OCR-Panel-Onomatopoeia",
+         torch_dtype=torch.bfloat16,
+         attn_implementation="flash_attention_2",
+         device_map=device,
+     )
+
+     base64_image = pil2base64(image)
+     messages = [
+         {"role": "user", "content": [
+             {"type": "image", "image": f"data:image;base64,{base64_image}"},
+             {"type": "text", "text": "With this image, please output the result of OCR with grounding."}
+         ]},
+     ]
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(model.device)
+
+     generated_ids = model.generate(**inputs, max_new_tokens=1024)
+     generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+     raw_output = processor.batch_decode(
+         generated_ids_trimmed,
+         skip_special_tokens=False,
+         clean_up_tokenization_spaces=False,
+     )[0]
+
+     print(raw_output)
+
+     result_image = image_inputs[0].copy()
+     draw = ImageDraw.Draw(result_image)
+     ocr_texts = []
+     for ocr_text, ocr_quad in re.findall(r"<\|object_ref_start\|>(.+?)<\|object_ref_end\|><\|quad_start\|>([\d,]+)<\|quad_end\|>", raw_output):
+         ocr_texts.append(f"{ocr_text} -> {ocr_quad}")
+         quad = [int(x) for x in ocr_quad.split(",")]
+         for i in range(4):
+             start_point = quad[i*2:i*2+2]
+             end_point = quad[i*2+2:i*2+4] if i < 3 else quad[:2]
+             draw.line(start_point + end_point, fill="red", width=4)
+     ocr_texts_str = "\n".join(ocr_texts)
+
+     return ocr_texts_str, result_image
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Manga Panel OCR")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label="Input Image", image_mode="RGB", type="pil")
+             input_button = gr.Button(value="Submit")
+         with gr.Column():
+             ocr_text = gr.Textbox(label="Result", lines=5)
+             ocr_image = gr.Image(label="OCR Result", type="pil", show_label=False)
+
+     input_button.click(
+         fn=inference_fn,
+         inputs=[input_image],
+         outputs=[ocr_text, ocr_image],
+     )
+     ocr_examples = gr.Examples(
+         examples=[],
+         fn=inference_fn,
+         inputs=[input_image],
+         outputs=[ocr_text, ocr_image],
+         cache_examples=False,
+     )
+
+ demo.queue().launch()
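
A quick illustration of the grounding format parsed in app.py above: the model emits <|object_ref_start|>…<|quad_end|> tags, which the regex and drawing loop turn into quadrilateral edges. The sample raw_output text and coordinates below are hypothetical, chosen only to show the tag layout; the regex itself is the one used in inference_fn.

import re

# Hypothetical model output; the text and coordinates are illustrative only.
raw_output = (
    "<|object_ref_start|>ドドド<|object_ref_end|>"
    "<|quad_start|>120,40,300,42,298,160,118,158<|quad_end|>"
)

pattern = r"<\|object_ref_start\|>(.+?)<\|object_ref_end\|><\|quad_start\|>([\d,]+)<\|quad_end\|>"
for text, quad_str in re.findall(pattern, raw_output):
    # The quad is x1,y1,x2,y2,x3,y3,x4,y4 in image pixel coordinates.
    quad = [int(x) for x in quad_str.split(",")]
    corners = [(quad[i * 2], quad[i * 2 + 1]) for i in range(4)]
    # Close the polygon by connecting the last corner back to the first,
    # mirroring the drawing loop in app.py.
    edges = [(corners[i], corners[(i + 1) % 4]) for i in range(4)]
    print(text, edges)
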
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate==1.3.0
+ qwen-vl-utils==0.0.10
+ torchvision==0.20.1 --extra-index-url https://download.pytorch.org/whl/cu121
+ transformers @ git+https://github.com/huggingface/transformers@6b550462139655d488d4c663086a63e98713c6b9