Reality123b committed
Commit 55aa708 · verified · 1 Parent(s): 9edae58

Update app.py

Files changed (1): app.py (+48, -28)
app.py CHANGED
@@ -13,6 +13,8 @@ import networkx as nx
 from collections import Counter
 import json
 from datetime import datetime
+from transformers import AutoProcessor, AutoModelForVision2Seq
+from transformers.image_utils import load_image
 
 @dataclass
 class ChatMessage:
@@ -32,11 +34,9 @@ class XylariaChat:
             model="mistralai/Mistral-Nemo-Instruct-2407",
             token=self.hf_token
         )
-
-        self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
-        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
 
         self.image_gen_api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
+        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
 
         self.conversation_history = []
         self.persistent_memory = []
@@ -97,6 +97,13 @@ class XylariaChat:
 
         self.chat_history_file = "chat_history.json"
 
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.vlm_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
+        self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+            "HuggingFaceTB/SmolVLM-Instruct",
+            torch_dtype=torch.bfloat16,
+            _attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
+        ).to(self.device)
 
     def update_internal_state(self, emotion_deltas, cognitive_load_deltas, introspection_delta, engagement_delta):
         for emotion, delta in emotion_deltas.items():
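
Note (editor): the initializer added above loads SmolVLM-Instruct eagerly, always in torch.bfloat16, and asks for flash_attention_2 whenever CUDA is available. Two caveats: flash_attention_2 also requires the separate flash-attn package (Transformers raises ImportError when it is missing), and bfloat16 is a poor default on CPU. The diff also assumes torch is already imported elsewhere in app.py. A defensive variant is sketched below; load_smolvlm is a hypothetical helper, not part of the commit.

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

def load_smolvlm(model_id="HuggingFaceTB/SmolVLM-Instruct"):
    # Hypothetical helper mirroring the commit's __init__ logic.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        # flash_attention_2 needs the flash-attn package; Transformers
        # raises ImportError when it is not installed.
        model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
            _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
        )
    except ImportError:
        # Fall back to the default attention implementation and fp32.
        model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            _attn_implementation="eager",
        )
    processor = AutoProcessor.from_pretrained(model_id)
    return processor, model.to(device)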
@@ -401,34 +408,44 @@ class XylariaChat:
             print(f"Error resetting API client: {e}")
 
         return None
-
-    def caption_image(self, image):
+
+    def caption_image_vlm(self, image, user_input):
         try:
-            if isinstance(image, str) and os.path.isfile(image):
-                with open(image, "rb") as f:
-                    data = f.read()
-            elif isinstance(image, str):
-                if image.startswith('data:image'):
-                    image = image.split(',')[1]
-                data = base64.b64decode(image)
+
+            if isinstance(image, str) and image.startswith('http'):
+                image = load_image(image)
+            elif isinstance(image, str) and os.path.isfile(image):
+                image = Image.open(image)
+            elif isinstance(image, str) and image.startswith('data:image'):
+                image = Image.open(base64.b64decode(image.split(',')[1]))
             else:
-                data = image.read()
-
-            response = requests.post(
-                self.image_api_url,
-                headers=self.image_api_headers,
-                data=data
+                image = Image.fromarray(image)
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": user_input}
+                    ]
+                },
+            ]
+
+            prompt = self.vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
+            inputs = self.vlm_processor(text=prompt, images=[image], return_tensors="pt")
+            inputs = inputs.to(self.device)
+
+            generated_ids = self.vlm_model.generate(**inputs, max_new_tokens=500)
+            generated_texts = self.vlm_processor.batch_decode(
+                generated_ids,
+                skip_special_tokens=True,
             )
-
-            if response.status_code == 200:
-                caption = response.json()[0].get('generated_text', 'No caption generated')
-                return caption
-            else:
-                return f"Error captioning image: {response.status_code} - {response.text}"
+
+            return generated_texts[0].split("Assistant: ")[-1]
 
         except Exception as e:
-            return f"Error processing image: {str(e)}"
+            return f"Error captioning image with VLM: {str(e)}"
 
     def generate_image(self, prompt):
         try:
             payload = {"inputs": prompt}
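
Note (editor): one branch of caption_image_vlm will not run as committed. PIL's Image.open accepts a file path or a file-like object, not raw bytes, so Image.open(base64.b64decode(...)) in the data-URI branch raises at runtime, and the surrounding except clause silently converts that into an error string. The method also uses Image without the diff adding an import, so from PIL import Image presumably already exists near the top of app.py. A corrected sketch of that branch, assuming io and PIL are available:

import base64
import io
from PIL import Image

def image_from_data_uri(data_uri):
    # Hypothetical helper: Image.open needs a file-like object,
    # so wrap the decoded bytes in io.BytesIO before opening.
    raw = base64.b64decode(data_uri.split(",")[1])
    return Image.open(io.BytesIO(raw))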
@@ -484,8 +501,11 @@
             messages.append(msg)
 
         if image:
-            image_caption = self.caption_image(image)
-            user_input = f"description of an image: {image_caption}\n\nUser's message about it: {user_input}"
+            image_caption = self.caption_image_vlm(image, user_input)
+            messages.append(ChatMessage(
+                role="user",
+                content=image_caption
+            ).to_dict())
 
         messages.append(ChatMessage(
             role="user",