import os
import random
import logging
import asyncio
from typing import Any, Generator, List, Optional, Tuple

import gradio as gr
from huggingface_hub import InferenceClient  # Adjust the import if you use a different client.

# Set up logging to capture errors and warnings.
logging.basicConfig(
    level=logging.INFO,
    filename='chatbot.log',
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Encapsulated configuration to avoid global variable pitfalls.
class ChatbotConfig:
    def __init__(
        self,
        max_history: int = 100,
        verbose: bool = True,
        max_iterations: int = 1000,
        max_new_tokens: int = 256,
        default_seed: Optional[int] = None
    ):
        self.max_history = max_history
        self.verbose = verbose
        self.max_iterations = max_iterations
        self.max_new_tokens = max_new_tokens
        # Keep None so generate() draws a fresh seed per call; set a fixed
        # value to make generations reproducible.
        self.default_seed = default_seed

# Global configuration instance.
config = ChatbotConfig()
# Externalize prompts into a dictionary, optionally overridden by environment variables.
PROMPTS = {
    "ACTION_PROMPT": os.environ.get("ACTION_PROMPT", "action prompt"),
    "ADD_PROMPT": os.environ.get("ADD_PROMPT", "add prompt"),
    "COMPRESS_HISTORY_PROMPT": os.environ.get("COMPRESS_HISTORY_PROMPT", "compress history prompt"),
    "LOG_PROMPT": os.environ.get("LOG_PROMPT", "log prompt"),
    "LOG_RESPONSE": os.environ.get("LOG_RESPONSE", "log response"),
    "MODIFY_PROMPT": os.environ.get("MODIFY_PROMPT", "modify prompt"),
    "PREFIX": os.environ.get("PREFIX", "prefix"),
    "SEARCH_QUERY": os.environ.get("SEARCH_QUERY", "search query"),
    "READ_PROMPT": os.environ.get("READ_PROMPT", "read prompt"),
    "TASK_PROMPT": os.environ.get("TASK_PROMPT", "task prompt"),
    "UNDERSTAND_TEST_RESULTS_PROMPT": os.environ.get("UNDERSTAND_TEST_RESULTS_PROMPT", "understand test results prompt")
}
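
# Example (hypothetical prompt text): any entry can be overridden without
# touching the code, e.g. before launching the app:
#   export TASK_PROMPT="You are an expert web developer. Plan the next task."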
# Instantiate the AI client.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt_var(message: str, history: List[str]) -> str:
    """
    Format the provided message and conversation history into the required prompt format.

    Args:
        message (str): The current instruction/message.
        history (List[str]): List of previous conversation entries.

    Returns:
        str: A formatted prompt string.

    Raises:
        TypeError: If message is not a string or any history entry is not a string.
    """
    if not isinstance(message, str):
        raise TypeError("The instruction message must be a string.")
    if not all(isinstance(item, str) for item in history):
        raise TypeError("All items in history must be strings.")
    history_text = "\n".join(history) if history else "No previous conversation."
    return f"\n### Instruction:\n{message}\n### History:\n{history_text}"
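
# Quick sanity check (illustrative values only):
#   >>> format_prompt_var("Fix the header", ["user: hi", "bot: hello"])
#   '\n### Instruction:\nFix the header\n### History:\nuser: hi\nbot: hello'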
def run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
    """
    Run the AI agent with the given instruction and conversation history.

    Args:
        instruction (str): The user instruction.
        history (List[str]): The conversation history.

    Returns:
        Tuple[str, List[str]]: A tuple containing the full AI response and a list of extracted actions.

    Raises:
        TypeError: If inputs are of invalid type.
    """
    if not isinstance(instruction, str):
        raise TypeError("Instruction must be a string.")
    if not isinstance(history, list) or not all(isinstance(item, str) for item in history):
        raise TypeError("History must be a list of strings.")
    response = ""
    iterations = 0
    try:
        # generate() formats the prompt itself, so pass the raw instruction
        # here; formatting it twice would nest one prompt template in another.
        for chunk in generate(instruction, history[-config.max_history:], temperature=0.7):
            response += chunk
            iterations += 1
            if "\n\n### Instruction:" in chunk or iterations >= config.max_iterations:
                break
    except Exception as e:
        logging.error("Error in run_agent: %s", e)
        response += f"\n[Error in run_agent: {e}]"
    # Extract actions from the response.
    response_actions = []
    for line in response.strip().split("\n"):
        if line.startswith("action:"):
            response_actions.append(line.split("action:", 1)[1].strip())
    return response, response_actions
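
# Example: a response containing the lines "action: open_file" and
# "action: run_tests" yields response_actions == ["open_file", "run_tests"].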
def generate(prompt: str, history: List[str], temperature: float) -> Generator[str, None, None]:
    """
    Generate text from the AI model using the formatted prompt.

    Args:
        prompt (str): The input prompt.
        history (List[str]): Recent conversation history.
        temperature (float): Sampling temperature.

    Yields:
        str: Incremental output from the text-generation stream.
    """
    seed = random.randint(1, 2**32 - 1) if config.default_seed is None else config.default_seed
    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": config.max_new_tokens,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "do_sample": True,
        "seed": seed,
    }
    formatted_prompt = format_prompt_var(prompt, history)
    try:
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False
        )
    except Exception as e:
        logging.error("Error during text_generation call: %s", e)
        yield f"[Error during text_generation call: {e}]"
        return
    iterations = 0
    for response in stream:
        iterations += 1
        try:
            # Yield only the newly generated token text; callers such as
            # run_agent accumulate the chunks into the full response.
            yield response.token.text
        except AttributeError as ae:
            logging.error("Malformed response token: %s", ae)
            yield f"[Malformed response token: {ae}]"
            break
        if iterations >= config.max_iterations:
            yield "\n[Response truncated due to length limitations]"
            break
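
# Illustrative consumption of the stream:
#   for chunk in generate("Summarize the repo", [], temperature=0.7):
#       print(chunk, end="", flush=True)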
async def async_run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
    """
    Asynchronous wrapper to run the agent in a separate thread.

    Args:
        instruction (str): The instruction for the AI.
        history (List[str]): The conversation history.

    Returns:
        Tuple[str, List[str]]: The response and extracted actions.
    """
    # asyncio.to_thread requires Python 3.9+.
    return await asyncio.to_thread(run_agent, instruction, history)
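
# Example (inside an async context, e.g. a Gradio async handler):
#   response, actions = await async_run_agent("List the project files", [])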
def clear_conversation() -> List[str]:
    """
    Clear the conversation history.

    Returns:
        List[str]: An empty conversation history.
    """
    return []
def update_chatbot_styles(history: List[Any]) -> Any:
    """
    Update the chatbot display based on the number of messages.

    Args:
        history (List[Any]): The current conversation history.

    Returns:
        Update object for the Gradio Chatbot.
    """
    # Gradio stores chat turns as [user, bot] pairs (lists or tuples).
    num_messages = sum(1 for item in history if isinstance(item, (list, tuple)))
    # Chatbot has no "num_messages" property; surface the count via the label.
    return gr.update(label=f"Conversation ({num_messages} messages)")
def update_max_history(value: int) -> int:
    """
    Update max_history in the configuration.

    Args:
        value (int): New maximum history value.

    Returns:
        int: The updated max_history.
    """
    config.max_history = int(value)
    return config.max_history
def create_interface() -> gr.Blocks:
    """
    Create and return the Gradio interface for the chatbot application.

    Returns:
        gr.Blocks: The Gradio Blocks object representing the UI.
    """
    async def respond(message: str, chat_history: List[Any]) -> Tuple[str, List[Any]]:
        # Flatten the chatbot's [user, bot] pairs into the plain-string
        # history run_agent expects, then append the new exchange.
        flat_history = [text for pair in chat_history for text in pair if text]
        response, _actions = await async_run_agent(message, flat_history)
        # Clear the textbox and return the updated chat history.
        return "", chat_history + [(message, response)]

    block = gr.Blocks()
    with block:
        gr.Markdown("## Expert Web Developer Assistant")
        with gr.Tab("Conversation"):
            # Components must be created inside the Blocks context.
            chatbot = gr.Chatbot()
            txt = gr.Textbox(show_label=False, placeholder="Type something...")
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear")
            # When text is submitted (or Send is clicked), run the agent asynchronously.
            txt.submit(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
            send_btn.click(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
            # Clear conversation history and update the chatbot UI.
            clear_btn.click(fn=clear_conversation, outputs=chatbot).then(
                update_chatbot_styles, chatbot, chatbot
            )
        with gr.Tab("Settings"):
            max_history_slider = gr.Slider(
                minimum=1, maximum=100, step=1,
                label="Max history",
                value=config.max_history
            )
            max_history_slider.change(
                update_max_history, max_history_slider, max_history_slider
            )
    return block
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
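
# Usage note (assumes `gradio` and `huggingface_hub` are installed and a
# Hugging Face token is configured, e.g. via `huggingface-cli login`): run
# this script and open the UI at http://127.0.0.1:7860, Gradio's default.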