Reality123b commited on
Commit
417372b
·
verified ·
1 Parent(s): 1637733

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -157
app.py CHANGED
@@ -5,7 +5,8 @@ import gradio as gr
5
  from huggingface_hub import InferenceClient
6
  from dataclasses import dataclass
7
  import pytesseract
8
- from PIL import Image
 
9
 
10
  @dataclass
11
  class ChatMessage:
@@ -30,8 +31,8 @@ class XylariaChat:
30
  api_key=self.hf_token
31
  )
32
 
33
- # Image captioning API setup
34
- self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
35
  self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
36
 
37
  # Initialize conversation history and persistent memory
@@ -74,38 +75,47 @@ class XylariaChat:
74
  """
75
  Caption an uploaded image using Hugging Face API
76
  Args:
77
- image (str): Base64 encoded image or file path
78
  Returns:
79
- str: Image caption or error message
80
  """
81
  try:
82
- # If image is a file path, read and encode
83
- if isinstance(image, str) and os.path.isfile(image):
84
- with open(image, "rb") as f:
85
- data = f.read()
86
- # If image is already base64 encoded
87
- elif isinstance(image, str):
88
- # Remove data URI prefix if present
89
- if image.startswith('data:image'):
90
- image = image.split(',')[1]
91
- data = base64.b64decode(image)
92
- # If image is a file-like object (unlikely with Gradio, but good to have)
93
- else:
94
- data = image.read()
95
-
96
- # Send request to Hugging Face API
97
- response = requests.post(
98
- self.image_api_url,
99
- headers=self.image_api_headers,
100
- data=data
101
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Check response
104
- if response.status_code == 200:
105
- caption = response.json()[0].get('generated_text', 'No caption generated')
106
- return caption
107
- else:
108
- return f"Error captioning image: {response.status_code} - {response.text}"
109
 
110
  except Exception as e:
111
  return f"Error processing image: {str(e)}"
@@ -113,10 +123,8 @@ class XylariaChat:
113
  def perform_math_ocr(self, image_path):
114
  """
115
  Perform OCR on an image and return the extracted text.
116
-
117
  Args:
118
  image_path (str): Path to the image file.
119
-
120
  Returns:
121
  str: Extracted text from the image, or an error message.
122
  """
@@ -133,12 +141,13 @@ class XylariaChat:
133
  except Exception as e:
134
  return f"Error during Math OCR: {e}"
135
 
136
- def get_response(self, user_input, image=None):
137
  """
138
  Generate a response using chat completions with improved error handling
139
  Args:
140
  user_input (str): User's message
141
- image (optional): Uploaded image
 
142
  Returns:
143
  Stream of chat completions or error message
144
  """
@@ -169,15 +178,25 @@ class XylariaChat:
169
  content=msg['content']
170
  ).to_dict())
171
 
172
- # Process image if uploaded
173
- if image:
174
- image_caption = self.caption_image(image)
175
- user_input = f"Uploaded image : {image_caption}\n\nUser's message: {user_input}"
 
 
 
 
 
 
 
 
 
 
176
 
177
  # Add user input
178
  messages.append(ChatMessage(
179
  role="user",
180
- content=user_input
181
  ).to_dict())
182
 
183
  # Calculate available tokens
@@ -223,31 +242,32 @@ class XylariaChat:
223
 
224
 
225
  def create_interface(self):
226
- def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
227
-
228
- ocr_text = ""
229
- if math_ocr_image_path:
230
- ocr_text = self.perform_math_ocr(math_ocr_image_path)
231
- if ocr_text.startswith("Error"):
232
- # Handle OCR error
233
- updated_history = chat_history + [[message, ocr_text]]
234
- yield "", updated_history, None, None
235
- return
236
- else:
237
- message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"
 
 
 
 
 
238
 
239
- # Check if an image was actually uploaded
240
- if image_filepath:
241
- response_stream = self.get_response(message, image_filepath)
242
- else:
243
- response_stream = self.get_response(message)
244
-
245
 
246
  # Handle errors in get_response
247
  if isinstance(response_stream, str):
248
  # Return immediately with the error message
249
  updated_history = chat_history + [[message, response_stream]]
250
- yield "", updated_history, None, None
251
  return
252
 
253
  # Prepare for streaming response
@@ -263,12 +283,12 @@ class XylariaChat:
263
 
264
  # Update the last message in chat history with partial response
265
  updated_history[-1][1] = full_response
266
- yield "", updated_history, None, None
267
  except Exception as e:
268
  print(f"Streaming error: {e}")
269
  # Display error in the chat interface
270
  updated_history[-1][1] = f"Error during response: {e}"
271
- yield "", updated_history, None, None
272
  return
273
 
274
  # Update conversation history
@@ -283,6 +303,9 @@ class XylariaChat:
283
  if len(self.conversation_history) > 10:
284
  self.conversation_history = self.conversation_history[-10:]
285
 
 
 
 
286
  # Custom CSS for Inter font and improved styling
287
  custom_css = """
288
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
@@ -297,38 +320,6 @@ class XylariaChat:
297
  .gradio-container button {
298
  font-family: 'Inter', sans-serif !important;
299
  }
300
- /* Image Upload Styling */
301
- .image-container {
302
- border: 1px solid #ccc;
303
- border-radius: 8px;
304
- padding: 10px;
305
- margin-bottom: 10px;
306
- display: flex;
307
- flex-direction: column;
308
- align-items: center;
309
- gap: 10px;
310
- background-color: #f8f8f8;
311
- }
312
- .image-preview {
313
- max-width: 200px;
314
- max-height: 200px;
315
- border-radius: 8px;
316
- }
317
- .image-buttons {
318
- display: flex;
319
- gap: 10px;
320
- }
321
- .image-buttons button {
322
- padding: 8px 15px;
323
- border-radius: 5px;
324
- background-color: #4CAF50;
325
- color: white;
326
- border: none;
327
- cursor: pointer;
328
- }
329
- .image-buttons button:hover {
330
- background-color: #367c39;
331
- }
332
  """
333
 
334
  with gr.Blocks(theme='soft', css=custom_css) as demo:
@@ -340,29 +331,6 @@ class XylariaChat:
340
  show_copy_button=True,
341
  )
342
 
343
- # Enhanced Image Upload Section
344
- with gr.Accordion("Image Input", open=False):
345
- with gr.Column() as image_container: # Use a Column for the image container
346
- img = gr.Image(
347
- sources=["upload", "webcam"],
348
- type="filepath",
349
- label="", # Remove label as it's redundant
350
- elem_classes="image-preview", # Add a class for styling
351
- )
352
- with gr.Row():
353
- clear_image_btn = gr.Button("Clear Image")
354
-
355
- with gr.Accordion("Math Input", open=False):
356
- with gr.Column():
357
- math_ocr_img = gr.Image(
358
- sources=["upload", "webcam"],
359
- type="filepath",
360
- label="Upload Image for math",
361
- elem_classes="image-preview"
362
- )
363
- with gr.Row():
364
- clear_math_ocr_btn = gr.Button("Clear Math Image")
365
-
366
  # Input row with improved layout
367
  with gr.Row():
368
  with gr.Column(scale=4):
@@ -371,70 +339,102 @@ class XylariaChat:
371
  placeholder="Type your message...",
372
  container=False
373
  )
374
- btn = gr.Button("Send", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
  # Clear history and memory buttons
377
  with gr.Row():
378
  clear = gr.Button("Clear Conversation")
379
  clear_memory = gr.Button("Clear Memory")
380
 
381
- # Clear image functionality
382
- clear_image_btn.click(
383
- fn=lambda: None,
384
- inputs=None,
385
- outputs=[img],
386
- queue=False
387
- )
388
-
389
- # Clear Math OCR image functionality
390
- clear_math_ocr_btn.click(
391
- fn=lambda: None,
392
- inputs=None,
393
- outputs=[math_ocr_img],
394
- queue=False
395
- )
396
-
397
  # Submit functionality with streaming and image support
 
398
  btn.click(
399
  fn=streaming_response,
400
- inputs=[txt, chatbot, img, math_ocr_img],
401
- outputs=[txt, chatbot, img, math_ocr_img]
402
  )
403
  txt.submit(
404
  fn=streaming_response,
405
- inputs=[txt, chatbot, img, math_ocr_img],
406
- outputs=[txt, chatbot, img, math_ocr_img]
 
 
 
 
 
 
407
  )
408
 
409
- # Clear conversation history
410
  clear.click(
411
- fn=lambda: None,
412
  inputs=None,
413
- outputs=[chatbot],
414
- queue=False
415
  )
416
 
417
- # Clear persistent memory and reset conversation
418
  clear_memory.click(
419
- fn=self.reset_conversation,
420
  inputs=None,
421
- outputs=[chatbot],
422
- queue=False
423
  )
424
 
425
- # Ensure memory is cleared when the interface is closed
426
- demo.load(self.reset_conversation, None, None)
427
-
428
  return demo
429
 
430
- # Launch the interface
431
- def main():
432
  chat = XylariaChat()
433
  interface = chat.create_interface()
434
- interface.launch(
435
- share=True, # Optional: create a public link
436
- debug=True # Show detailed errors
437
- )
438
-
439
- if __name__ == "__main__":
440
- main()
 
5
  from huggingface_hub import InferenceClient
6
  from dataclasses import dataclass
7
  import pytesseract
8
+ from PIL import Image, ImageGrab
9
+ import io
10
 
11
  @dataclass
12
  class ChatMessage:
 
31
  api_key=self.hf_token
32
  )
33
 
34
+ # Image captioning API setup with the new model
35
+ self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
36
  self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
37
 
38
  # Initialize conversation history and persistent memory
 
75
  """
76
  Caption an uploaded image using Hugging Face API
77
  Args:
78
+ image (str or list): Base64 encoded image(s), file path(s), or file-like object(s)
79
  Returns:
80
+ str: Concatenated image captions or error message
81
  """
82
  try:
83
+ # Ensure image is a list
84
+ if not isinstance(image, list):
85
+ image = [image]
86
+
87
+ captions = []
88
+ for img in image:
89
+ # If image is a file path, read and encode
90
+ if isinstance(img, str) and os.path.isfile(img):
91
+ with open(img, "rb") as f:
92
+ data = f.read()
93
+ # If image is already base64 encoded
94
+ elif isinstance(img, str):
95
+ # Remove data URI prefix if present
96
+ if img.startswith('data:image'):
97
+ img = img.split(',')[1]
98
+ data = base64.b64decode(img)
99
+ # If image is a file-like object
100
+ else:
101
+ data = img.read()
102
+
103
+ # Send request to Hugging Face API
104
+ response = requests.post(
105
+ self.image_api_url,
106
+ headers=self.image_api_headers,
107
+ data=data
108
+ )
109
+
110
+ # Check response
111
+ if response.status_code == 200:
112
+ caption = response.json()[0].get('generated_text', 'No caption generated')
113
+ captions.append(caption)
114
+ else:
115
+ captions.append(f"Error captioning image: {response.status_code} - {response.text}")
116
 
117
+ # Return concatenated captions
118
+ return "\n".join(captions)
 
 
 
 
119
 
120
  except Exception as e:
121
  return f"Error processing image: {str(e)}"
 
123
  def perform_math_ocr(self, image_path):
124
  """
125
  Perform OCR on an image and return the extracted text.
 
126
  Args:
127
  image_path (str): Path to the image file.
 
128
  Returns:
129
  str: Extracted text from the image, or an error message.
130
  """
 
141
  except Exception as e:
142
  return f"Error during Math OCR: {e}"
143
 
144
+ def get_response(self, user_input, images=None, math_ocr_image=None):
145
  """
146
  Generate a response using chat completions with improved error handling
147
  Args:
148
  user_input (str): User's message
149
+ images (list, optional): List of uploaded images
150
+ math_ocr_image (str, optional): Path to math OCR image
151
  Returns:
152
  Stream of chat completions or error message
153
  """
 
178
  content=msg['content']
179
  ).to_dict())
180
 
181
+ # Process images if uploaded
182
+ image_context = ""
183
+ if images and any(images):
184
+ image_caption = self.caption_image(images)
185
+ image_context += f"Uploaded images: {image_caption}\n\n"
186
+
187
+ # Process math OCR image if uploaded
188
+ if math_ocr_image:
189
+ ocr_text = self.perform_math_ocr(math_ocr_image)
190
+ if not ocr_text.startswith("Error"):
191
+ image_context += f"Math OCR Result: {ocr_text}\n\n"
192
+
193
+ # Combine image context with user input
194
+ full_input = image_context + user_input
195
 
196
  # Add user input
197
  messages.append(ChatMessage(
198
  role="user",
199
+ content=full_input
200
  ).to_dict())
201
 
202
  # Calculate available tokens
 
242
 
243
 
244
  def create_interface(self):
245
+ def get_clipboard_image():
246
+ """Capture image from clipboard"""
247
+ try:
248
+ img = ImageGrab.grabclipboard()
249
+ if img is not None:
250
+ # Save clipboard image to a temporary file
251
+ temp_path = "clipboard_image.png"
252
+ img.save(temp_path)
253
+ return temp_path
254
+ return None
255
+ except Exception as e:
256
+ print(f"Error getting clipboard image: {e}")
257
+ return None
258
+
259
+ def streaming_response(message, chat_history, image1, image2, image3, image4, image5, math_ocr_image_path):
260
+ # Collect non-None images
261
+ images = [img for img in [image1, image2, image3, image4, image5] if img is not None]
262
 
263
+ # Generate response
264
+ response_stream = self.get_response(message, images, math_ocr_image_path)
 
 
 
 
265
 
266
  # Handle errors in get_response
267
  if isinstance(response_stream, str):
268
  # Return immediately with the error message
269
  updated_history = chat_history + [[message, response_stream]]
270
+ yield ("", updated_history) + ((None,) * 6)
271
  return
272
 
273
  # Prepare for streaming response
 
283
 
284
  # Update the last message in chat history with partial response
285
  updated_history[-1][1] = full_response
286
+ yield ("", updated_history) + ((None,) * 6)
287
  except Exception as e:
288
  print(f"Streaming error: {e}")
289
  # Display error in the chat interface
290
  updated_history[-1][1] = f"Error during response: {e}"
291
+ yield ("", updated_history) + ((None,) * 6)
292
  return
293
 
294
  # Update conversation history
 
303
  if len(self.conversation_history) > 10:
304
  self.conversation_history = self.conversation_history[-10:]
305
 
306
+ # Reset image inputs after processing
307
+ yield ("", updated_history, None, None, None, None, None, None)
308
+
309
  # Custom CSS for Inter font and improved styling
310
  custom_css = """
311
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
 
320
  .gradio-container button {
321
  font-family: 'Inter', sans-serif !important;
322
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  """
324
 
325
  with gr.Blocks(theme='soft', css=custom_css) as demo:
 
331
  show_copy_button=True,
332
  )
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  # Input row with improved layout
335
  with gr.Row():
336
  with gr.Column(scale=4):
 
339
  placeholder="Type your message...",
340
  container=False
341
  )
342
+
343
+ # Image and Math upload buttons
344
+ with gr.Column(scale=1):
345
+ # Buttons for image and math uploads with symbolic icons
346
+ with gr.Row():
347
+ img_upload_btn = gr.Button("🖼️") # Image upload button
348
+ math_upload_btn = gr.Button("➗") # Math upload button
349
+ clipboard_btn = gr.Button("📋") # Clipboard paste button
350
+
351
+ # Multiple image inputs
352
+ with gr.Accordion("Images", open=False):
353
+ with gr.Column():
354
+ with gr.Row():
355
+ img1 = gr.Image(
356
+ sources=["upload", "webcam"],
357
+ type="filepath",
358
+ label="Image 1",
359
+ height=200
360
+ )
361
+ img2 = gr.Image(
362
+ sources=["upload", "webcam"],
363
+ type="filepath",
364
+ label="Image 2",
365
+ height=200
366
+ )
367
+ with gr.Row():
368
+ img3 = gr.Image(
369
+ sources=["upload", "webcam"],
370
+ type="filepath",
371
+ label="Image 3",
372
+ height=200
373
+ )
374
+ img4 = gr.Image(
375
+ sources=["upload", "webcam"],
376
+ type="filepath",
377
+ label="Image 4",
378
+ height=200
379
+ )
380
+ img5 = gr.Image(
381
+ sources=["upload", "webcam"],
382
+ type="filepath",
383
+ label="Image 5",
384
+ height=200
385
+ )
386
+
387
+ # Math OCR Image Upload
388
+ with gr.Accordion("Math Input", open=False):
389
+ math_ocr_img = gr.Image(
390
+ sources=["upload", "webcam"],
391
+ type="filepath",
392
+ label="Upload Image for math",
393
+ height=200
394
+ )
395
 
396
  # Clear history and memory buttons
397
  with gr.Row():
398
  clear = gr.Button("Clear Conversation")
399
  clear_memory = gr.Button("Clear Memory")
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  # Submit functionality with streaming and image support
402
+ btn = gr.Button("Send")
403
  btn.click(
404
  fn=streaming_response,
405
+ inputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img],
406
+ outputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img]
407
  )
408
  txt.submit(
409
  fn=streaming_response,
410
+ inputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img],
411
+ outputs=[txt, chatbot, img1, img2, img3, img4, img5, math_ocr_img]
412
+ )
413
+
414
+ # Clipboard button functionality
415
+ clipboard_btn.click(
416
+ fn=get_clipboard_image,
417
+ outputs=[img1]
418
  )
419
 
420
+ # Clear conversation button
421
  clear.click(
422
+ fn=self.reset_conversation,
423
  inputs=None,
424
+ outputs=[chatbot, txt, img1, img2, img3, img4, img5, math_ocr_img]
 
425
  )
426
 
427
+ # Clear memory button
428
  clear_memory.click(
429
+ fn=lambda: self.persistent_memory.clear(),
430
  inputs=None,
431
+ outputs=[]
 
432
  )
433
 
 
 
 
434
  return demo
435
 
436
+ # Optional: If you want to run the interface
437
+ if __name__ == "__main__":
438
  chat = XylariaChat()
439
  interface = chat.create_interface()
440
+ interface.launch()