prithivMLmods committed on
Commit de1dce3 · verified · 1 Parent(s): 554bd83

Update app.py

Files changed (1)
  1. app.py +45 -1
app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
     TextIteratorStreamer,
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
+    AutoModelForImageTextToText,  # <-- New import for aya-vision
 )
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
@@ -51,6 +52,16 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# --- New feature: aya-vision ---
+AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
+aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID, trust_remote_code=True)
+aya_model = AutoModelForImageTextToText.from_pretrained(
+    AYA_MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
+# --------------------------------
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -188,6 +199,38 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
+
+    # --- New branch for @aya-vision feature ---
+    if lower_text.startswith("@aya-vision"):
+        prompt_clean = re.sub(r"@aya-vision", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if not files:
+            yield "Please provide an image for @aya-vision command."
+            return
+        image = load_image(files[0])
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": prompt_clean},
+            ]
+        }]
+        prompt_aya = aya_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = aya_processor(text=[prompt_aya], images=[image], return_tensors="pt", padding=True).to("cuda")
+        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        yield "💭 Processing @aya-vision..."
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        return
+    # ------------------------------------------
+
     # Check if the prompt is an image generation command using model flags.
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
@@ -345,13 +388,14 @@ demo = gr.ChatInterface(
         ['@lightningv4 A serene landscape with mountains'],
         ['@turbov3 Abstract art, colorful and vibrant'],
         ["@tts2 What causes rainbows to form?"],
+        ["@aya-vision Describe the content of this image"],
    ],
     cache_examples=False,
     type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 for image gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="@aya-vision for img-txt-txt / use the tags @lightningv5 @lightningv4 @turbov3 or @aya-vision for image-based commands!"),
     stop_btn="Stop Generation",
     multimodal=True,
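
For reference, the sketch below condenses the aya-vision path added in this commit into a single standalone, non-streaming script. It is a minimal sketch, not the app itself: the image URL, the prompt, and the max_new_tokens=300 cap are placeholder values chosen here, and it assumes a CUDA device and a transformers version that ships AutoModelForImageTextToText.

# Minimal standalone sketch of the aya-vision usage from this commit.
# Assumptions: CUDA GPU available; transformers new enough to provide
# AutoModelForImageTextToText; image URL and prompt below are placeholders.
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers.image_utils import load_image

AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
processor = AutoProcessor.from_pretrained(AYA_MODEL_ID, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    AYA_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda").eval()

# Placeholder inputs: any local path or URL and any question work here.
image = load_image("https://example.com/sample.jpg")
prompt = "Describe the content of this image"

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": prompt},
    ],
}]
chat_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Cast floating-point inputs (pixel_values) to fp16 to match the model weights.
inputs = processor(text=[chat_prompt], images=[image], return_tensors="pt", padding=True).to("cuda", torch.float16)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=300)  # arbitrary cap for the sketch

# Drop the prompt tokens before decoding so only the model's answer is printed.
answer_ids = output_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(answer_ids, skip_special_tokens=True)[0])

In app.py the same calls run on a background Thread with a TextIteratorStreamer so partial output can be yielded to the chat UI as it is generated; the sketch calls generate() synchronously to keep the example short.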