# API / app.py
# Reality123b's picture
# Update app.py
# 3674c04 verified
# raw
# history blame
# 12 kB
import os
import base64
import requests
import gradio as gr
from huggingface_hub import InferenceClient
from dataclasses import dataclass
@dataclass
class ChatMessage:
    """One chat turn, defined locally because huggingface_hub supplies no such class.

    Attributes are documented inline below; instances are converted to plain
    dictionaries with :meth:`to_dict` before being sent to the chat API.
    """

    # Speaker of the turn; this file uses "system", "user" and "assistant".
    role: str
    # Text of the turn.
    content: str

    def to_dict(self):
        """Return the message as a JSON-serializable ``role``/``content`` dict."""
        return dict(role=self.role, content=self.content)
class XylariaChat:
    """Gradio chat application backed by the Hugging Face Inference API.

    The instance owns the inference client, a rolling conversation history,
    and a small key/value "persistent memory" that is injected into every
    prompt as an extra system message.

    NOTE: history and memory are stored on the instance, so every browser
    session served by the same instance shares (and can reset) the same state.
    """

    # Chat model requested from the Inference API (client default and per call).
    MODEL_ID = "Qwen/QwQ-32B-Preview"
    # Maximum number of history entries kept (user and assistant turns combined).
    MAX_HISTORY_MESSAGES = 10
    # Seconds to wait for the captioning endpoint; without a timeout
    # ``requests.post`` may block indefinitely on a stalled connection.
    CAPTION_TIMEOUT_SECONDS = 60

    def __init__(self):
        """Read the API token from ``HF_TOKEN`` and set up clients and state.

        Raises:
            ValueError: If the ``HF_TOKEN`` environment variable is unset or empty.
        """
        # Securely load HuggingFace token
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        # Initialize the inference client with the Qwen model
        self.client = self._new_client()

        # Image captioning API setup
        self.image_api_url = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

        # Initialize conversation history and persistent memory
        self.conversation_history = []
        self.persistent_memory = {}

        # System prompt sent as the first message of every request
        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin(india, 12 year old). You should think step-by-step."""

    def _new_client(self):
        """Create an inference client bound to ``MODEL_ID`` and the stored token."""
        return InferenceClient(
            model=self.MODEL_ID,
            api_key=self.hf_token
        )

    def store_information(self, key, value):
        """Store ``value`` under ``key`` in persistent memory.

        Returns:
            str: Confirmation text of the form ``"Stored: <key> = <value>"``.
        """
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
        """Return the remembered value for ``key``.

        Returns:
            The stored value, or the string
            ``"No information found for this key."`` when ``key`` is unknown.
        """
        return self.persistent_memory.get(key, "No information found for this key.")

    def clear_conversation(self):
        """Forget the conversation history while keeping persistent memory.

        Returns:
            None, so that Gradio clears the chatbot component wired as output.
        """
        self.conversation_history = []
        return None

    def reset_conversation(self):
        """Reset conversation history and persistent memory, and rebuild the client.

        A failure while recreating the client is printed and otherwise ignored
        so that the reset itself always completes.

        Returns:
            None, so that Gradio clears the chatbot component wired as output.
        """
        # Clear local memory
        self.conversation_history = []
        self.persistent_memory.clear()

        # Reinitialize the client
        try:
            self.client = self._new_client()
        except Exception as e:
            print(f"Error resetting API client: {e}")

        return None  # To clear the chatbot interface

    def caption_image(self, image):
        """Caption an image through the Hugging Face captioning endpoint.

        Args:
            image: A path to an image file, a base64 string (optionally with a
                ``data:image...`` URI prefix), or an object with a ``read()``
                method returning the image bytes.

        Returns:
            str: The generated caption. Failures are not raised; they are
            reported as a string starting with ``"Error captioning image:"``
            (non-200 response) or ``"Error processing image:"`` (any exception).
        """
        try:
            if isinstance(image, str) and os.path.isfile(image):
                # File path: send the raw bytes
                with open(image, "rb") as f:
                    data = f.read()
            elif isinstance(image, str):
                # Base64 text: drop the data-URI header (everything up to the
                # first comma) if present, then decode
                if image.startswith('data:image'):
                    image = image.split(',', 1)[1]
                data = base64.b64decode(image)
            else:
                # File-like object
                data = image.read()

            # Send request to Hugging Face API
            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=self.CAPTION_TIMEOUT_SECONDS
            )

            if response.status_code == 200:
                return response.json()[0].get('generated_text', 'No caption generated')
            return f"Error captioning image: {response.text}"
        except Exception as e:
            return f"Error processing image: {str(e)}"

    def _build_messages(self, user_input):
        """Assemble the message dicts for one request.

        Order: system prompt, remembered information (if any), prior
        conversation history, then the new user message.
        """
        messages = [ChatMessage(role="system", content=self.system_prompt).to_dict()]

        # Add persistent memory context if available
        if self.persistent_memory:
            memory_context = "Remembered Information:\n" + "\n".join(
                f"{k}: {v}" for k, v in self.persistent_memory.items()
            )
            messages.append(ChatMessage(role="system", content=memory_context).to_dict())

        messages.extend(
            ChatMessage(role=msg['role'], content=msg['content']).to_dict()
            for msg in self.conversation_history
        )
        messages.append(ChatMessage(role="user", content=user_input).to_dict())
        return messages

    def get_response(self, user_input, image=None):
        """Request a streamed chat completion for ``user_input``.

        Args:
            user_input (str): User's message.
            image (optional): Uploaded image in any form accepted by
                :meth:`caption_image`. Its caption is prepended to the message.

        Returns:
            The streaming completion object on success, or a ``str`` starting
            with ``"Error generating response:"`` if anything raised. Callers
            distinguish the two cases with ``isinstance(result, str)``.
        """
        try:
            # Process image if uploaded
            if image:
                image_caption = self.caption_image(image)
                user_input = f"Image description: {image_caption}\n\nUser's message: {user_input}"

            messages = self._build_messages(user_input)

            # Generate response with streaming
            stream = self.client.chat.completions.create(
                model=self.MODEL_ID,
                messages=messages,  # Plain dictionaries
                temperature=0.5,
                max_tokens=10240,
                top_p=0.7,
                stream=True
            )
            return stream
        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def create_interface(self):
        """Build the Gradio Blocks UI and wire its events.

        Returns:
            The ``gr.Blocks`` application, not yet launched.
        """

        def streaming_response(message, chat_history, image_filepath):
            """Stream the reply into the chatbot.

            Yields ``(textbox value, chatbot history, image value)`` tuples;
            the textbox and image are cleared on every update.
            """
            # The chatbot value may be None; normalise so list concatenation
            # below cannot raise.
            chat_history = list(chat_history or [])

            # Nothing to send: do not call the API with an empty user turn.
            if not (message and message.strip()) and not image_filepath:
                yield "", chat_history, None
                return

            # get_response ignores a falsy image, so the path is passed as-is.
            response_stream = self.get_response(message, image_filepath)

            # Handle errors in get_response
            if isinstance(response_stream, str):
                # Return immediately with the error message
                yield "", chat_history + [[message, response_stream]], None
                return

            # Prepare for streaming response
            full_response = ""
            updated_history = chat_history + [[message, ""]]

            # Streaming output
            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        # Update the last message in chat history with partial response
                        updated_history[-1][1] = full_response
                        yield "", updated_history, None
            except Exception as e:
                print(f"Streaming error: {e}")
                # Display error in the chat interface
                updated_history[-1][1] = f"Error during response: {e}"
                yield "", updated_history, None
                return

            # Record the completed exchange (only reached when streaming succeeded)
            self.conversation_history.append(
                {"role": "user", "content": message}
            )
            self.conversation_history.append(
                {"role": "assistant", "content": full_response}
            )

            # Limit conversation history
            if len(self.conversation_history) > self.MAX_HISTORY_MESSAGES:
                self.conversation_history = self.conversation_history[-self.MAX_HISTORY_MESSAGES:]

        # Custom CSS for Inter font
        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        """

        # NOTE(review): the nesting of the layout containers below was
        # reconstructed from flattened source - confirm against the intended UI.
        with gr.Blocks(theme='soft', css=custom_css) as demo:
            # Chat interface with improved styling
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.4 Senoa (Qwen Model)",
                    height=500,
                    show_copy_button=True,
                )

                # Input row with improved layout and image upload
                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message...",
                            container=False
                        )

                    # Image upload as a separate button
                    with gr.Row():
                        img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="Upload Image",
                            visible=False
                        )
                        upload_btn = gr.Button("Upload Image")

                    btn = gr.Button("Send", scale=1)

                # Clear history and memory buttons
                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

            # Image upload toggle: reveal the hidden image component
            upload_btn.click(
                fn=lambda: gr.update(visible=True),
                inputs=None,
                outputs=[img]
            )

            # Submit functionality with streaming and image support
            btn.click(
                fn=streaming_response,
                inputs=[txt, chatbot, img],
                outputs=[txt, chatbot, img]
            )
            txt.submit(
                fn=streaming_response,
                inputs=[txt, chatbot, img],
                outputs=[txt, chatbot, img]
            )

            # Clear conversation history: empties the chat widget AND the
            # history sent to the model (persistent memory is kept)
            clear.click(
                fn=self.clear_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Clear persistent memory and reset conversation
            clear_memory.click(
                fn=self.reset_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Reset history and memory whenever the page is loaded in a browser
            demo.load(self.reset_conversation, None, None)

        return demo
# Launch the interface
def main():
    """Create the chat application, build its Gradio interface, and launch it."""
    demo = XylariaChat().create_interface()
    launch_options = {
        "share": True,  # Optional: create a public link
        "debug": True,  # Show detailed errors
    }
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()