Reality123b committed
Commit 01cbb26 · verified · 1 Parent(s): 21418e6

Update app.py

Files changed (1): app.py (+110 -47)
app.py CHANGED
@@ -9,42 +9,57 @@ from PIL import Image

@dataclass
class ChatMessage:
+    """Custom ChatMessage class since huggingface_hub doesn't provide one"""
    role: str
    content: str

    def to_dict(self):
+        """Converts ChatMessage to a dictionary for JSON serialization."""
        return {"role": self.role, "content": self.content}

class XylariaChat:
    def __init__(self):
+        # Securely load HuggingFace token
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

+        # Initialize the inference client with the Qwen model
        self.client = InferenceClient(
-            model="Qwen/QwQ-32B-Preview",
+            model="Qwen/QwQ-32B-Preview",  # Using the specified model
            api_key=self.hf_token
        )

-        self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
+        # Image captioning API setup
+        self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

+        # Initialize conversation history and persistent memory
        self.conversation_history = []
        self.persistent_memory = {}

+        # System prompt with more detailed instructions
        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria, developed by Sk Md Saad Amin. You should think step-by-step."""

    def store_information(self, key, value):
+        """Store important information in persistent memory"""
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
+        """Retrieve information from persistent memory"""
        return self.persistent_memory.get(key, "No information found for this key.")

    def reset_conversation(self):
+        """
+        Completely reset the conversation history, persistent memory,
+        and clear API-side memory
+        """
+        # Clear local memory
        self.conversation_history = []
        self.persistent_memory.clear()

+        # Reinitialize the client (not strictly necessary for the API, but can help with local state)
        try:
            self.client = InferenceClient(
                model="Qwen/QwQ-32B-Preview",
@@ -53,26 +68,39 @@
        except Exception as e:
            print(f"Error resetting API client: {e}")

-        return None
+        return None  # To clear the chatbot interface

    def caption_image(self, image):
+        """
+        Caption an uploaded image using Hugging Face API
+        Args:
+            image (str): Base64 encoded image or file path
+        Returns:
+            str: Image caption or error message
+        """
        try:
+            # If image is a file path, read and encode
            if isinstance(image, str) and os.path.isfile(image):
                with open(image, "rb") as f:
                    data = f.read()
+            # If image is already base64 encoded
            elif isinstance(image, str):
+                # Remove data URI prefix if present
                if image.startswith('data:image'):
                    image = image.split(',')[1]
                data = base64.b64decode(image)
+            # If image is a file-like object (unlikely with Gradio, but good to have)
            else:
                data = image.read()

+            # Send request to Hugging Face API
            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data
            )

+            # Check response
            if response.status_code == 200:
                caption = response.json()[0].get('generated_text', 'No caption generated')
                return caption
@@ -83,22 +111,46 @@
            return f"Error processing image: {str(e)}"

    def perform_math_ocr(self, image_path):
+        """
+        Perform OCR on an image and return the extracted text.
+        Args:
+            image_path (str): Path to the image file.
+        Returns:
+            str: Extracted text from the image, or an error message.
+        """
        try:
+            # Open the image using Pillow library
            img = Image.open(image_path)
+
+            # Use Tesseract to do OCR on the image
            text = pytesseract.image_to_string(img)
+
+            # Remove leading/trailing whitespace and return
            return text.strip()
+
        except Exception as e:
            return f"Error during Math OCR: {e}"
-
+
    def get_response(self, user_input, image=None):
+        """
+        Generate a response using chat completions with improved error handling
+        Args:
+            user_input (str): User's message
+            image (optional): Uploaded image
+        Returns:
+            Stream of chat completions or error message
+        """
        try:
+            # Prepare messages with conversation context and persistent memory
            messages = []

+            # Add system prompt as first message
            messages.append(ChatMessage(
                role="system",
                content=self.system_prompt
            ).to_dict())

+            # Add persistent memory context if available
            if self.persistent_memory:
                memory_context = "Remembered Information:\n" + "\n".join(
                    [f"{k}: {v}" for k, v in self.persistent_memory.items()]
@@ -108,23 +160,29 @@
                    content=memory_context
                ).to_dict())

+            # Convert existing conversation history to ChatMessage objects and then to dictionaries
            for msg in self.conversation_history:
                messages.append(msg)

+            # Process image if uploaded
            if image:
                image_caption = self.caption_image(image)
                user_input = f"description of an image: {image_caption}\n\nUser's message about it: {user_input}"

+            # Add user input
            messages.append(ChatMessage(
                role="user",
                content=user_input
            ).to_dict())

+            # Calculate available tokens
            input_tokens = sum(len(msg['content'].split()) for msg in messages)
-            max_new_tokens = 16384 - input_tokens - 50
+            max_new_tokens = 16384 - input_tokens - 50  # Reserve some tokens for safety

+            # Limit max_new_tokens to prevent exceeding the total limit
            max_new_tokens = min(max_new_tokens, 10020)

+            # Generate response with streaming
            stream = self.client.chat_completion(
                messages=messages,
                model="Qwen/QwQ-32B-Preview",
@@ -133,14 +191,20 @@
                top_p=0.9,
                stream=True
            )
-
+
            return stream
-
+
        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def messages_to_prompt(self, messages):
+        """
+        Convert a list of ChatMessage dictionaries to a single prompt string.
+
+        This is a simple implementation and you might need to adjust it
+        based on the specific requirements of the model you are using.
+        """
        prompt = ""
        for msg in messages:
            if msg["role"] == "system":
@@ -149,59 +213,68 @@
                prompt += f"<|user|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "assistant":
                prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
-        prompt += "<|assistant|>\n"
+        prompt += "<|assistant|>\n"  # Start of assistant's turn
        return prompt
-
+
+
    def create_interface(self):
        def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
+
            ocr_text = ""
            if math_ocr_image_path:
                ocr_text = self.perform_math_ocr(math_ocr_image_path)
                if ocr_text.startswith("Error"):
-                    updated_history = chat_history + [[{"role": "user", "content": message}, {"role": "assistant", "content": ocr_text}]]
-                    yield "", updated_history, None, None
-                    return
-                elif len(ocr_text) > 500:
-                    ocr_text = "OCR output is too large to be processed."
-                    updated_history = chat_history + [[{"role": "user", "content": message}, {"role": "assistant", "content": ocr_text}]]
+                    # Handle OCR error
+                    updated_history = chat_history + [[message, ocr_text]]
                    yield "", updated_history, None, None
                    return
                else:
                    message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"

+            # Check if an image was actually uploaded
            if image_filepath:
                response_stream = self.get_response(message, image_filepath)
            else:
                response_stream = self.get_response(message)
+

+            # Handle errors in get_response
            if isinstance(response_stream, str):
-                updated_history = chat_history + [[{"role": "user", "content": message}, {"role": "assistant", "content": response_stream}]]
+                # Return immediately with the error message
+                updated_history = chat_history + [[message, response_stream]]
                yield "", updated_history, None, None
                return

+            # Prepare for streaming response
            full_response = ""
-            updated_history = chat_history + [[{"role": "user", "content": message}, {"role": "assistant", "content": ""}]]
+            updated_history = chat_history + [[message, ""]]

+            # Streaming output
            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        chunk_content = chunk.choices[0].delta.content
                        full_response += chunk_content
-
-                        updated_history[-1][1]["content"] = full_response
+
+                        # Update the last message in chat history with partial response
+                        updated_history[-1][1] = full_response
                        yield "", updated_history, None, None
            except Exception as e:
                print(f"Streaming error: {e}")
-                updated_history[-1][1]["content"] = f"Error during response: {e}"
+                # Display error in the chat interface
+                updated_history[-1][1] = f"Error during response: {e}"
                yield "", updated_history, None, None
                return

+            # Update conversation history
            self.conversation_history.append(ChatMessage(role="user", content=message).to_dict())
            self.conversation_history.append(ChatMessage(role="assistant", content=full_response).to_dict())

+            # Limit conversation history
            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

+        # Custom CSS for Inter font and improved styling
        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
        body, .gradio-container {
@@ -215,6 +288,7 @@
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
+        /* Image Upload Styling */
        .image-container {
            display: flex;
            gap: 10px;
@@ -231,9 +305,11 @@
            max-height: 200px;
            border-radius: 8px;
        }
+        /* Remove clear image buttons */
        .clear-button {
            display: none;
        }
+        /* Animate chatbot messages */
        .chatbot-container .message {
            opacity: 0;
            animation: fadeIn 0.5s ease-in-out forwards;
@@ -248,27 +324,20 @@
                transform: translateY(0);
            }
        }
-        .gradio-accordion {
-            overflow: hidden;
-            transition: max-height 0.3s ease-in-out;
-            max-height: 0;
-        }
-        .gradio-accordion.open {
-            max-height: 500px;
-        }
        """

        with gr.Blocks(theme='soft', css=custom_css) as demo:
+            # Chat interface with improved styling
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.5 Senoa (EXPERIMENTAL)",
                    height=500,
                    show_copy_button=True,
-                    type='messages'
                )

-                with gr.Accordion("Image Input", open=False) as accordion:
-                    with gr.Row(elem_classes="image-container"):
+                # Enhanced Image Upload Section
+                with gr.Accordion("Image Input", open=False):
+                    with gr.Row(elem_classes="image-container"):  # Use a Row for side-by-side layout
                        with gr.Column(elem_classes="image-upload"):
                            img = gr.Image(
                                sources=["upload", "webcam"],
@@ -283,7 +352,9 @@
                                label="Upload Image for Math OCR",
                                elem_classes="image-preview"
                            )
+                    # Removed clear buttons as per requirement

+                # Input row with improved layout
                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
@@ -293,10 +364,12 @@
                        )
                    btn = gr.Button("Send", scale=1)

+                # Clear history and memory buttons
                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

+                # Submit functionality with streaming and image support
                btn.click(
                    fn=streaming_response,
                    inputs=[txt, chatbot, img, math_ocr_img],
@@ -308,6 +381,7 @@
                    outputs=[txt, chatbot, img, math_ocr_img]
                )

+                # Clear conversation history
                clear.click(
                    fn=lambda: None,
                    inputs=None,
@@ -315,6 +389,7 @@
                    queue=False
                )

+                # Clear persistent memory and reset conversation
                clear_memory.click(
                    fn=self.reset_conversation,
                    inputs=None,
@@ -322,30 +397,18 @@
                    queue=False
                )

-            demo.load(None, None, None, _js="""
-            () => {
-                const accordion = document.querySelector(".gradio-accordion");
-
-                if (accordion) {
-                    const accordionHeader = accordion.querySelector(".label-wrap");
-
-                    accordionHeader.addEventListener("click", () => {
-                        accordion.classList.toggle("open");
-                    });
-                }
-            }
-            """)
-
+            # Reset conversation state when the interface loads
            demo.load(self.reset_conversation, None, None)

        return demo

+# Launch the interface
def main():
    chat = XylariaChat()
    interface = chat.create_interface()
    interface.launch(
-        share=False,
-        debug=True
+        share=True,   # Optional: create a public link
+        debug=True    # Show detailed errors
    )

if __name__ == "__main__":
9
 
10
  @dataclass
11
  class ChatMessage:
12
+ """Custom ChatMessage class since huggingface_hub doesn't provide one"""
13
  role: str
14
  content: str
15
 
16
  def to_dict(self):
17
+ """Converts ChatMessage to a dictionary for JSON serialization."""
18
  return {"role": self.role, "content": self.content}
19
 
20
  class XylariaChat:
21
  def __init__(self):
22
+ # Securely load HuggingFace token
23
  self.hf_token = os.getenv("HF_TOKEN")
24
  if not self.hf_token:
25
  raise ValueError("HuggingFace token not found in environment variables")
26
 
27
+ # Initialize the inference client with the Qwen model
28
  self.client = InferenceClient(
29
+ model="Qwen/QwQ-32B-Preview", # Using the specified model
30
  api_key=self.hf_token
31
  )
32
 
33
+ # Image captioning API setup
34
+ self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
35
  self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
36
 
37
+ # Initialize conversation history and persistent memory
38
  self.conversation_history = []
39
  self.persistent_memory = {}
40
 
41
+ # System prompt with more detailed instructions
42
  self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin . You should think step-by-step."""
43
 
44
  def store_information(self, key, value):
45
+ """Store important information in persistent memory"""
46
  self.persistent_memory[key] = value
47
  return f"Stored: {key} = {value}"
48
 
49
  def retrieve_information(self, key):
50
+ """Retrieve information from persistent memory"""
51
  return self.persistent_memory.get(key, "No information found for this key.")
52
 
53
  def reset_conversation(self):
54
+ """
55
+ Completely reset the conversation history, persistent memory,
56
+ and clear API-side memory
57
+ """
58
+ # Clear local memory
59
  self.conversation_history = []
60
  self.persistent_memory.clear()
61
 
62
+ # Reinitialize the client (not strictly necessary for the API, but can help with local state)
63
  try:
64
  self.client = InferenceClient(
65
  model="Qwen/QwQ-32B-Preview",
 
68
  except Exception as e:
69
  print(f"Error resetting API client: {e}")
70
 
71
+ return None # To clear the chatbot interface
72
 
73
  def caption_image(self, image):
74
+ """
75
+ Caption an uploaded image using Hugging Face API
76
+ Args:
77
+ image (str): Base64 encoded image or file path
78
+ Returns:
79
+ str: Image caption or error message
80
+ """
81
  try:
82
+ # If image is a file path, read and encode
83
  if isinstance(image, str) and os.path.isfile(image):
84
  with open(image, "rb") as f:
85
  data = f.read()
86
+ # If image is already base64 encoded
87
  elif isinstance(image, str):
88
+ # Remove data URI prefix if present
89
  if image.startswith('data:image'):
90
  image = image.split(',')[1]
91
  data = base64.b64decode(image)
92
+ # If image is a file-like object (unlikely with Gradio, but good to have)
93
  else:
94
  data = image.read()
95
 
96
+ # Send request to Hugging Face API
97
  response = requests.post(
98
  self.image_api_url,
99
  headers=self.image_api_headers,
100
  data=data
101
  )
102
 
103
+ # Check response
104
  if response.status_code == 200:
105
  caption = response.json()[0].get('generated_text', 'No caption generated')
106
  return caption
 
111
  return f"Error processing image: {str(e)}"
112
 
113
  def perform_math_ocr(self, image_path):
114
+ """
115
+ Perform OCR on an image and return the extracted text.
116
+ Args:
117
+ image_path (str): Path to the image file.
118
+ Returns:
119
+ str: Extracted text from the image, or an error message.
120
+ """
121
  try:
122
+ # Open the image using Pillow library
123
  img = Image.open(image_path)
124
+
125
+ # Use Tesseract to do OCR on the image
126
  text = pytesseract.image_to_string(img)
127
+
128
+ # Remove leading/trailing whitespace and return
129
  return text.strip()
130
+
131
  except Exception as e:
132
  return f"Error during Math OCR: {e}"
133
+
134
  def get_response(self, user_input, image=None):
135
+ """
136
+ Generate a response using chat completions with improved error handling
137
+ Args:
138
+ user_input (str): User's message
139
+ image (optional): Uploaded image
140
+ Returns:
141
+ Stream of chat completions or error message
142
+ """
143
  try:
144
+ # Prepare messages with conversation context and persistent memory
145
  messages = []
146
 
147
+ # Add system prompt as first message
148
  messages.append(ChatMessage(
149
  role="system",
150
  content=self.system_prompt
151
  ).to_dict())
152
 
153
+ # Add persistent memory context if available
154
  if self.persistent_memory:
155
  memory_context = "Remembered Information:\n" + "\n".join(
156
  [f"{k}: {v}" for k, v in self.persistent_memory.items()]
 
160
  content=memory_context
161
  ).to_dict())
162
 
163
+ # Convert existing conversation history to ChatMessage objects and then to dictionaries
164
  for msg in self.conversation_history:
165
  messages.append(msg)
166
 
167
+ # Process image if uploaded
168
  if image:
169
  image_caption = self.caption_image(image)
170
  user_input = f"description of an image: {image_caption}\n\nUser's message about it: {user_input}"
171
 
172
+ # Add user input
173
  messages.append(ChatMessage(
174
  role="user",
175
  content=user_input
176
  ).to_dict())
177
 
178
+ # Calculate available tokens
179
  input_tokens = sum(len(msg['content'].split()) for msg in messages)
180
+ max_new_tokens = 16384 - input_tokens - 50 # Reserve some tokens for safety
181
 
182
+ # Limit max_new_tokens to prevent exceeding the total limit
183
  max_new_tokens = min(max_new_tokens, 10020)
184
 
185
+ # Generate response with streaming
186
  stream = self.client.chat_completion(
187
  messages=messages,
188
  model="Qwen/QwQ-32B-Preview",
 
191
  top_p=0.9,
192
  stream=True
193
  )
194
+
195
  return stream
196
+
197
  except Exception as e:
198
  print(f"Detailed error in get_response: {e}")
199
  return f"Error generating response: {str(e)}"
200
 
201
  def messages_to_prompt(self, messages):
202
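The captioning backend this commit switches to (microsoft/git-large-coco) can
be exercised outside the app with the same raw-bytes POST that caption_image
uses. A standalone sketch mirroring the URL, headers, and response handling
above; it assumes HF_TOKEN is set in the environment, as the app requires:

    import os
    import requests

    API_URL = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
    HEADERS = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

    def caption_file(path):
        # POST raw image bytes, exactly as caption_image does
        with open(path, "rb") as f:
            data = f.read()
        response = requests.post(API_URL, headers=HEADERS, data=data)
        response.raise_for_status()  # surface non-200 responses explicitly
        return response.json()[0].get("generated_text", "No caption generated")

    # Example (hypothetical local file):
    #   print(caption_file("example.jpg"))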
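One behavioral detail worth noting: with type='messages' removed from
gr.Chatbot, Gradio expects its default pair format, which is what the new
[[message, response]] history entries provide. For comparison (illustrative
literals only):

    # Default pair format, as streaming_response builds it after this commit:
    pair_history = [
        ["What's in this image?", "A cat sitting on a windowsill."],
    ]

    # Messages format, required when gr.Chatbot(type="messages") is set,
    # which is what the previous revision used:
    message_history = [
        {"role": "user", "content": "What's in this image?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ]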