Reality123b committed on
Commit 6baa45b · verified · 1 Parent(s): 417372b

Update app.py

Files changed (1)
  1. app.py +155 -157
app.py CHANGED
@@ -5,8 +5,7 @@ import gradio as gr
  from huggingface_hub import InferenceClient
  from dataclasses import dataclass
  import pytesseract
- from PIL import Image, ImageGrab
- import io
+ from PIL import Image

  @dataclass
  class ChatMessage:
@@ -31,8 +30,8 @@ class XylariaChat:
  api_key=self.hf_token
  )

- # Image captioning API setup with the new model
- self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
+ # Image captioning API setup
+ self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
  self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

  # Initialize conversation history and persistent memory
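The new endpoint follows the same serverless Inference API pattern as before, only with a different model id. A minimal standalone sketch of that request shape, assuming an HF_TOKEN environment variable and a hypothetical local file example.jpg (the response parsing mirrors the caption_image code in the next hunk):

import os
import requests

API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

# Post raw image bytes; captioning models return a list of dicts with 'generated_text'.
with open("example.jpg", "rb") as f:
    resp = requests.post(API_URL, headers=headers, data=f.read())

if resp.status_code == 200:
    print(resp.json()[0].get("generated_text", "No caption generated"))
else:
    print(f"Error {resp.status_code}: {resp.text}")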
@@ -75,47 +74,38 @@ class XylariaChat:
  """
  Caption an uploaded image using Hugging Face API
  Args:
- image (str or list): Base64 encoded image(s), file path(s), or file-like object(s)
+ image (str): Base64 encoded image or file path
  Returns:
- str: Concatenated image captions or error message
+ str: Image caption or error message
  """
  try:
- # Ensure image is a list
- if not isinstance(image, list):
- image = [image]
-
- captions = []
- for img in image:
- # If image is a file path, read and encode
- if isinstance(img, str) and os.path.isfile(img):
- with open(img, "rb") as f:
- data = f.read()
- # If image is already base64 encoded
- elif isinstance(img, str):
- # Remove data URI prefix if present
- if img.startswith('data:image'):
- img = img.split(',')[1]
- data = base64.b64decode(img)
- # If image is a file-like object
- else:
- data = img.read()
-
- # Send request to Hugging Face API
- response = requests.post(
- self.image_api_url,
- headers=self.image_api_headers,
- data=data
- )
-
- # Check response
- if response.status_code == 200:
- caption = response.json()[0].get('generated_text', 'No caption generated')
- captions.append(caption)
- else:
- captions.append(f"Error captioning image: {response.status_code} - {response.text}")
-
- # Return concatenated captions
- return "\n".join(captions)
+ # If image is a file path, read and encode
+ if isinstance(image, str) and os.path.isfile(image):
+ with open(image, "rb") as f:
+ data = f.read()
+ # If image is already base64 encoded
+ elif isinstance(image, str):
+ # Remove data URI prefix if present
+ if image.startswith('data:image'):
+ image = image.split(',')[1]
+ data = base64.b64decode(image)
+ # If image is a file-like object (unlikely with Gradio, but good to have)
+ else:
+ data = image.read()
+
+ # Send request to Hugging Face API
+ response = requests.post(
+ self.image_api_url,
+ headers=self.image_api_headers,
+ data=data
+ )
+
+ # Check response
+ if response.status_code == 200:
+ caption = response.json()[0].get('generated_text', 'No caption generated')
+ return caption
+ else:
+ return f"Error captioning image: {response.status_code} - {response.text}"

  except Exception as e:
  return f"Error processing image: {str(e)}"
@@ -141,13 +131,12 @@ class XylariaChat:
  except Exception as e:
  return f"Error during Math OCR: {e}"

- def get_response(self, user_input, images=None, math_ocr_image=None):
+ def get_response(self, user_input, image=None):
  """
  Generate a response using chat completions with improved error handling
  Args:
  user_input (str): User's message
- images (list, optional): List of uploaded images
- math_ocr_image (str, optional): Path to math OCR image
+ image (optional): Uploaded image
  Returns:
  Stream of chat completions or error message
  """
@@ -178,25 +167,15 @@ class XylariaChat:
  content=msg['content']
  ).to_dict())

- # Process images if uploaded
- image_context = ""
- if images and any(images):
- image_caption = self.caption_image(images)
- image_context += f"Uploaded images: {image_caption}\n\n"
-
- # Process math OCR image if uploaded
- if math_ocr_image:
- ocr_text = self.perform_math_ocr(math_ocr_image)
- if not ocr_text.startswith("Error"):
- image_context += f"Math OCR Result: {ocr_text}\n\n"
-
- # Combine image context with user input
- full_input = image_context + user_input
+ # Process image if uploaded
+ if image:
+ image_caption = self.caption_image(image)
+ user_input = f"Uploaded image : {image_caption}\n\nUser's message: {user_input}"

  # Add user input
  messages.append(ChatMessage(
  role="user",
- content=full_input
+ content=user_input
  ).to_dict())

  # Calculate available tokens
@@ -242,32 +221,31 @@ class XylariaChat:


  def create_interface(self):
- def get_clipboard_image():
- """Capture image from clipboard"""
- try:
- img = ImageGrab.grabclipboard()
- if img is not None:
- # Save clipboard image to a temporary file
- temp_path = "clipboard_image.png"
- img.save(temp_path)
- return temp_path
- return None
- except Exception as e:
- print(f"Error getting clipboard image: {e}")
- return None
-
- def streaming_response(message, chat_history, image1, image2, image3, image4, image5, math_ocr_image_path):
- # Collect non-None images
- images = [img for img in [image1, image2, image3, image4, image5] if img is not None]
-
- # Generate response
- response_stream = self.get_response(message, images, math_ocr_image_path)
+ def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
+
+ ocr_text = ""
+ if math_ocr_image_path:
+ ocr_text = self.perform_math_ocr(math_ocr_image_path)
+ if ocr_text.startswith("Error"):
+ # Handle OCR error
+ updated_history = chat_history + [[message, ocr_text]]
+ yield "", updated_history, None, None
+ return
+ else:
+ message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"
+
+ # Check if an image was actually uploaded
+ if image_filepath:
+ response_stream = self.get_response(message, image_filepath)
+ else:
+ response_stream = self.get_response(message)
+

  # Handle errors in get_response
  if isinstance(response_stream, str):
  # Return immediately with the error message
  updated_history = chat_history + [[message, response_stream]]
- yield ("", updated_history) + ((None,) * 6)
+ yield "", updated_history, None, None
  return

  # Prepare for streaming response
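The handler above is a plain Python generator: each yield hands Gradio a new (textbox value, chat history, image, math image) tuple, and the UI re-renders after every partial update. A stripped-down sketch of that streaming pattern, assuming a recent Gradio release with the pair-style chat history this app uses (the token list is a dummy stand-in for the model stream):

import time
import gradio as gr

def stream_reply(message, history):
    history = history + [[message, ""]]          # add a new user/assistant turn
    for token in ["Hello", ", ", "world", "!"]:  # dummy tokens instead of a model stream
        history[-1][1] += token
        time.sleep(0.1)
        yield "", history                        # clear the textbox, push the partial reply

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    txt = gr.Textbox(placeholder="Type your message...")
    txt.submit(fn=stream_reply, inputs=[txt, chatbot], outputs=[txt, chatbot])

if __name__ == "__main__":
    demo.launch()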
@@ -283,12 +261,12 @@ class XylariaChat:

  # Update the last message in chat history with partial response
  updated_history[-1][1] = full_response
- yield ("", updated_history) + ((None,) * 6)
+ yield "", updated_history, None, None
  except Exception as e:
  print(f"Streaming error: {e}")
  # Display error in the chat interface
  updated_history[-1][1] = f"Error during response: {e}"
- yield ("", updated_history) + ((None,) * 6)
+ yield "", updated_history, None, None
  return

  # Update conversation history
@@ -303,9 +281,6 @@ class XylariaChat:
  if len(self.conversation_history) > 10:
  self.conversation_history = self.conversation_history[-10:]

- # Reset image inputs after processing
- yield ("", updated_history, None, None, None, None, None, None)
-
  # Custom CSS for Inter font and improved styling
  custom_css = """
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
@@ -320,6 +295,38 @@ class XylariaChat:
  .gradio-container button {
  font-family: 'Inter', sans-serif !important;
  }
+ /* Image Upload Styling */
+ .image-container {
+ border: 1px solid #ccc;
+ border-radius: 8px;
+ padding: 10px;
+ margin-bottom: 10px;
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ gap: 10px;
+ background-color: #f8f8f8;
+ }
+ .image-preview {
+ max-width: 200px;
+ max-height: 200px;
+ border-radius: 8px;
+ }
+ .image-buttons {
+ display: flex;
+ gap: 10px;
+ }
+ .image-buttons button {
+ padding: 8px 15px;
+ border-radius: 5px;
+ background-color: #4CAF50;
+ color: white;
+ border: none;
+ cursor: pointer;
+ }
+ .image-buttons button:hover {
+ background-color: #367c39;
+ }
  """

  with gr.Blocks(theme='soft', css=custom_css) as demo:
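The new rules target components through Gradio's elem_classes hook: a class name set on a component ends up on its wrapper element, so the custom_css string passed to gr.Blocks can style it. A minimal sketch of that pairing, assuming a recent Gradio release (the css string is trimmed to the one rule the diff relies on):

import gradio as gr

css = ".image-preview { max-width: 200px; max-height: 200px; border-radius: 8px; }"

with gr.Blocks(css=css) as demo:
    # elem_classes attaches the CSS class so the rule above applies to this component
    gr.Image(type="filepath", elem_classes="image-preview")

if __name__ == "__main__":
    demo.launch()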
@@ -331,6 +338,29 @@ class XylariaChat:
  show_copy_button=True,
  )

+ # Enhanced Image Upload Section
+ with gr.Accordion("Image Input", open=False):
+ with gr.Column() as image_container: # Use a Column for the image container
+ img = gr.Image(
+ sources=["upload", "webcam"],
+ type="filepath",
+ label="", # Remove label as it's redundant
+ elem_classes="image-preview", # Add a class for styling
+ )
+ with gr.Row():
+ clear_image_btn = gr.Button("Clear Image")
+
+ with gr.Accordion("Math Input", open=False):
+ with gr.Column():
+ math_ocr_img = gr.Image(
+ sources=["upload", "webcam"],
+ type="filepath",
+ label="Upload Image for math",
+ elem_classes="image-preview"
+ )
+ with gr.Row():
+ clear_math_ocr_btn = gr.Button("Clear Math Image")
+
  # Input row with improved layout
  with gr.Row():
  with gr.Column(scale=4):
@@ -339,102 +369,70 @@ class XylariaChat:
  placeholder="Type your message...",
  container=False
  )
-
- # Image and Math upload buttons
- with gr.Column(scale=1):
- # Buttons for image and math uploads with symbolic icons
- with gr.Row():
- img_upload_btn = gr.Button("🖼️") # Image upload button
- math_upload_btn = gr.Button("➗") # Math upload button
- clipboard_btn = gr.Button("📋") # Clipboard paste button
-
- # Multiple image inputs
- with gr.Accordion("Images", open=False):
- with gr.Column():
- with gr.Row():
- img1 = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Image 1",
- height=200
- )
- img2 = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Image 2",
- height=200
- )
- with gr.Row():
- img3 = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Image 3",
- height=200
- )
- img4 = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Image 4",
- height=200
- )
- img5 = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Image 5",
- height=200
- )
-
- # Math OCR Image Upload
- with gr.Accordion("Math Input", open=False):
- math_ocr_img = gr.Image(
- sources=["upload", "webcam"],
- type="filepath",
- label="Upload Image for math",
- height=200
- )
+ btn = gr.Button("Send", scale=1)

  # Clear history and memory buttons
  with gr.Row():
  clear = gr.Button("Clear Conversation")
  clear_memory = gr.Button("Clear Memory")

+ # Clear image functionality
+ clear_image_btn.click(
+ fn=lambda: None,
+ inputs=None,
+ outputs=[img],
+ queue=False
+ )
+
+ # Clear Math OCR image functionality
+ clear_math_ocr_btn.click(
+ fn=lambda: None,
+ inputs=None,
+ outputs=[math_ocr_img],
+ queue=False
+ )
+
  # Submit functionality with streaming and image support
- btn = gr.Button("Send")
  btn.click(
  fn=streaming_response,
- inputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img],
- outputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img]
+ inputs=[txt, chatbot, img, math_ocr_img],
+ outputs=[txt, chatbot, img, math_ocr_img]
  )
  txt.submit(
  fn=streaming_response,
- inputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img],
- outputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img]
+ inputs=[txt, chatbot, img, math_ocr_img],
+ outputs=[txt, chatbot, img, math_ocr_img]
  )

- # Clipboard button functionality
- clipboard_btn.click(
- fn=get_clipboard_image,
- outputs=[img1]
- )
-
- # Clear conversation button
+ # Clear conversation history
  clear.click(
- fn=self.reset_conversation,
+ fn=lambda: None,
  inputs=None,
- outputs=[chatbot, txt, img1, img2, img3, img4, img5, math_ocr_img]
+ outputs=[chatbot],
+ queue=False
  )

- # Clear memory button
+ # Clear persistent memory and reset conversation
  clear_memory.click(
- fn=lambda: self.persistent_memory.clear(),
+ fn=self.reset_conversation,
  inputs=None,
- outputs=[]
+ outputs=[chatbot],
+ queue=False
  )

+ # Ensure memory is cleared when the interface is closed
+ demo.load(self.reset_conversation, None, None)
+
  return demo

- # Optional: If you want to run the interface
- if __name__ == "__main__":
+ # Launch the interface
+ def main():
  chat = XylariaChat()
  interface = chat.create_interface()
- interface.launch()
+ interface.launch(
+ share=True, # Optional: create a public link
+ debug=True # Show detailed errors
+ )
+
+ if __name__ == "__main__":
+ main()
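The new clear buttons rely on a small Gradio idiom: an event handler that returns None for an output component resets that component, and queue=False lets the reset bypass the request queue so it feels instant. A minimal sketch of the same wiring in isolation:

import gradio as gr

with gr.Blocks() as demo:
    img = gr.Image(type="filepath")
    clear_btn = gr.Button("Clear Image")
    # Returning None clears the image component; queue=False skips the queue for a snappy reset.
    clear_btn.click(fn=lambda: None, inputs=None, outputs=[img], queue=False)

if __name__ == "__main__":
    demo.launch()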