import os
import base64
import requests
import gradio as gr
from huggingface_hub import InferenceClient
from dataclasses import dataclass


@dataclass
class ChatMessage:
    """Custom ChatMessage class since huggingface_hub doesn't provide one"""
    role: str
    content: str
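
# Quick sanity check for the dataclass above (hypothetical values):
#   msg = ChatMessage(role="user", content="Hello, Xylaria")
#   assert (msg.role, msg.content) == ("user", "Hello, Xylaria")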


class XylariaChat:
    def __init__(self):
        # Securely load HuggingFace token
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        # Initialize the inference client with the Qwen model
        self.client = InferenceClient(
            model="Qwen/QwQ-32B-Preview",
            api_key=self.hf_token
        )
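
        # Note: the OpenAI-compatible self.client.chat.completions.create(...)
        # call used below assumes a reasonably recent huggingface_hub release
        # (the compatibility layer landed around v0.22); on older versions the
        # equivalent call is client.chat_completion(...).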

        # Image captioning API setup
        self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

        # Initialize conversation history and persistent memory
        self.conversation_history = []
        self.persistent_memory = {}

        # System prompt with more detailed instructions
        self.system_prompt = """You are a helpful and harmless AI assistant named Xylaria.
Always think step-by-step and provide clear, thoughtful responses.
Be kind, ethical, and supportive in your interactions."""

    def store_information(self, key, value):
        """Store important information in persistent memory"""
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
        """Retrieve information from persistent memory"""
        return self.persistent_memory.get(key, "No information found for this key.")
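
    # Sketch of the memory round trip (hypothetical key/value):
    #   chat.store_information("favorite_color", "blue")  # -> "Stored: favorite_color = blue"
    #   chat.retrieve_information("favorite_color")       # -> "blue"
    #   chat.retrieve_information("unknown_key")          # -> "No information found for this key."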

    def reset_conversation(self):
        """
        Completely reset the conversation history, persistent memory,
        and clear API-side memory
        """
        # Clear local memory
        self.conversation_history = []
        self.persistent_memory.clear()

        # Reinitialize the client
        try:
            self.client = InferenceClient(
                model="Qwen/QwQ-32B-Preview",
                api_key=self.hf_token
            )
        except Exception as e:
            print(f"Error resetting API client: {e}")

        return None  # To clear the chatbot interface

    def caption_image(self, image):
        """
        Caption an uploaded image using Hugging Face API

        Args:
            image (str): Base64 encoded image or file path

        Returns:
            str: Image caption or error message
        """
        try:
            # If image is a file path, read and encode
            if isinstance(image, str) and os.path.isfile(image):
                with open(image, "rb") as f:
                    data = f.read()
            # If image is already base64 encoded
            elif isinstance(image, str):
                # Remove data URI prefix if present
                if image.startswith('data:image'):
                    image = image.split(',')[1]
                data = base64.b64decode(image)
            # If image is a file-like object
            else:
                data = image.read()

            # Send request to Hugging Face API
            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=60  # avoid hanging indefinitely on a slow endpoint
            )

            # Check response
            if response.status_code == 200:
                caption = response.json()[0].get('generated_text', 'No caption generated')
                return caption
            else:
                return f"Error captioning image: {response.text}"
        except Exception as e:
            return f"Error processing image: {str(e)}"

    def get_response(self, user_input, image=None):
        """
        Generate a response using chat completions with improved error handling

        Args:
            user_input (str): User's message
            image (optional): Uploaded image

        Returns:
            Stream of chat completions or error message
        """
        try:
            # Prepare messages with conversation context and persistent memory
            messages = []

            # Add system prompt as first message
            messages.append(ChatMessage(
                role="system",
                content=self.system_prompt
            ))

            # Add persistent memory context if available
            if self.persistent_memory:
                memory_context = "Remembered Information:\n" + "\n".join(
                    [f"{k}: {v}" for k, v in self.persistent_memory.items()]
                )
                messages.append(ChatMessage(
                    role="system",
                    content=memory_context
                ))

            # Convert existing conversation history to ChatMessage objects
            for msg in self.conversation_history:
                messages.append(ChatMessage(
                    role=msg['role'],
                    content=msg['content']
                ))

            # Process image if uploaded
            if image:
                image_caption = self.caption_image(image)
                user_input = f"Image description: {image_caption}\n\nUser's message: {user_input}"

            # Add user input
            messages.append(ChatMessage(
                role="user",
                content=user_input
            ))

            # Generate response with streaming; serialize the dataclasses to plain
            # dicts, the format the chat completion endpoint expects
            stream = self.client.chat.completions.create(
                model="Qwen/QwQ-32B-Preview",
                messages=[{"role": m.role, "content": m.content} for m in messages],
                temperature=0.5,
                max_tokens=10240,
                top_p=0.7,
                stream=True
            )

            return stream
        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def create_interface(self):
        def streaming_response(message, chat_history, image):
            # Clear input textbox
            response_stream = self.get_response(message, image)

            # If it's an error string, surface it immediately; this function is a
            # generator, so the error must be yielded rather than returned
            if isinstance(response_stream, str):
                yield "", chat_history + [
                    {"role": "user", "content": message},
                    {"role": "assistant", "content": response_stream}
                ], None
                return

            # Prepare for streaming response (messages-format history entries,
            # matching the Chatbot's type="messages" below)
            full_response = ""
            updated_history = chat_history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": ""}
            ]

            # Streaming output
            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        chunk_content = chunk.choices[0].delta.content
                        full_response += chunk_content

                        # Update the last message in chat history with partial response
                        updated_history[-1]["content"] = full_response
                        yield "", updated_history, None
            except Exception as e:
                print(f"Streaming error: {e}")
                yield "", updated_history + [
                    {"role": "assistant", "content": f"Error during response: {e}"}
                ], None

            # Update conversation history
            self.conversation_history.append(
                {"role": "user", "content": message}
            )
            self.conversation_history.append(
                {"role": "assistant", "content": full_response}
            )

            # Limit conversation history to prevent token overflow
            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

        # Custom CSS for Inter font
        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        """

        with gr.Blocks(theme='soft', css=custom_css) as demo:
            # Chat interface with improved styling
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.4 Senoa (Qwen Model)",
                    height=500,
                    show_copy_button=True,
                    type="messages"
                )

                # Input row with improved layout and image upload
                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message...",
                            container=False
                        )

                        # Image upload as a separate button
                        with gr.Row():
                            img = gr.Image(
                                sources=["upload", "webcam"],
                                type="filepath",
                                label="Upload Image",
                                visible=False
                            )
                            upload_btn = gr.Button("Upload Image")

                    btn = gr.Button("Send", scale=1)

                # Clear history and memory buttons
                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

            # Image upload toggle
            upload_btn.click(
                fn=lambda: gr.update(visible=True),
                inputs=None,
                outputs=[img]
            )

            # Submit functionality with streaming and image support
            btn.click(
                fn=streaming_response,
                inputs=[txt, chatbot, img],
                outputs=[txt, chatbot, img]
            )
            txt.submit(
                fn=streaming_response,
                inputs=[txt, chatbot, img],
                outputs=[txt, chatbot, img]
            )

            # Clear conversation history
            clear.click(
                fn=lambda: None,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Clear persistent memory and reset conversation
            clear_memory.click(
                fn=self.reset_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Reset conversation and memory whenever the page loads or reloads
            demo.load(self.reset_conversation, None, None)

        return demo


# Launch the interface
def main():
    chat = XylariaChat()
    interface = chat.create_interface()
    interface.launch(
        share=True,  # Optional: create a public link
        debug=True   # Show detailed errors
    )


if __name__ == "__main__":
    main()
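
# Assumed local setup: export a Hugging Face access token before launching,
# e.g. `export HF_TOKEN=hf_...`, then run this file with `python app.py`
# (app.py being the usual Spaces entry-point name, not something this script enforces).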