import os
import random
import logging
import gradio as gr
import asyncio
from typing import List, Tuple, Generator, Any, Optional
from huggingface_hub import InferenceClient  # Client for the Hugging Face Inference API

# Set up logging to capture errors and warnings.
logging.basicConfig(
    level=logging.INFO,
    filename='chatbot.log',
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Encapsulated configuration to avoid global variable pitfalls.
class ChatbotConfig:
    def __init__(
        self,
        max_history: int = 100,
        verbose: bool = True,
        max_iterations: int = 1000,
        max_new_tokens: int = 256,
        default_seed: Optional[int] = None
    ):
        self.max_history = max_history
        self.verbose = verbose
        self.max_iterations = max_iterations
        self.max_new_tokens = max_new_tokens
        # None means generate() draws a fresh random seed on every call;
        # set an int here for reproducible sampling.
        self.default_seed = default_seed

# Global configuration instance.
config = ChatbotConfig()
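
# Illustrative alternative configuration (values here are arbitrary examples):
# a fixed seed makes sampling reproducible, and fewer tokens shortens replies.
#   config = ChatbotConfig(max_new_tokens=128, default_seed=42)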

# Externalize prompts into a dictionary, optionally overridden by environment variables.
PROMPTS = {
    "ACTION_PROMPT": os.environ.get("ACTION_PROMPT", "action prompt"),
    "ADD_PROMPT": os.environ.get("ADD_PROMPT", "add prompt"),
    "COMPRESS_HISTORY_PROMPT": os.environ.get("COMPRESS_HISTORY_PROMPT", "compress history prompt"),
    "LOG_PROMPT": os.environ.get("LOG_PROMPT", "log prompt"),
    "LOG_RESPONSE": os.environ.get("LOG_RESPONSE", "log response"),
    "MODIFY_PROMPT": os.environ.get("MODIFY_PROMPT", "modify prompt"),
    "PREFIX": os.environ.get("PREFIX", "prefix"),
    "SEARCH_QUERY": os.environ.get("SEARCH_QUERY", "search query"),
    "READ_PROMPT": os.environ.get("READ_PROMPT", "read prompt"),
    "TASK_PROMPT": os.environ.get("TASK_PROMPT", "task prompt"),
    "UNDERSTAND_TEST_RESULTS_PROMPT": os.environ.get("UNDERSTAND_TEST_RESULTS_PROMPT", "understand test results prompt")
}
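
# Illustrative override at launch time (hypothetical value and filename):
#   ACTION_PROMPT="Decide the next action to take." python chatbot.py
# The os.environ.get() defaults above apply only when no override is set.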

# Instantiate the AI client.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def format_prompt_var(message: str, history: List[str]) -> str:
    """
    Format the provided message and conversation history into the required prompt format.

    Args:
        message (str): The current instruction/message.
        history (List[str]): List of previous conversation entries.
    
    Returns:
        str: A formatted prompt string.
    
    Raises:
        TypeError: If message is not a string or any history entry is not a string.
    """
    if not isinstance(message, str):
        raise TypeError("The instruction message must be a string.")
    if not all(isinstance(item, str) for item in history):
        raise TypeError("All items in history must be strings.")
    
    history_text = "\n".join(history) if history else "No previous conversation."
    prompt = f"\n### Instruction:\n{message}\n### History:\n{history_text}"
    return prompt
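
# Illustrative doctest-style example of the prompt format built above:
#   >>> format_prompt_var("Add a navbar", ["Build a landing page", "Done."])
#   '\n### Instruction:\nAdd a navbar\n### History:\nBuild a landing page\nDone.'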

def run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
    """
    Run the AI agent with the given instruction and conversation history.

    Args:
        instruction (str): The user instruction.
        history (List[str]): The conversation history.
    
    Returns:
        Tuple[str, List[str]]: A tuple containing the full AI response and a list of extracted actions.
    
    Raises:
        TypeError: If inputs are of invalid type.
    """
    if not isinstance(instruction, str):
        raise TypeError("Instruction must be a string.")
    if not isinstance(history, list) or not all(isinstance(item, str) for item in history):
        raise TypeError("History must be a list of strings.")

    response = ""
    iterations = 0

    try:
        # generate() formats the prompt itself, so pass the raw instruction
        # here; pre-formatting would wrap the message in the template twice.
        for chunk in generate(instruction, history[-config.max_history:], temperature=0.7):
            response += chunk
            iterations += 1
            # Check the accumulated response, since the stop sequence may
            # span more than one streamed chunk.
            if "\n\n### Instruction:" in response or iterations >= config.max_iterations:
                break
    except Exception as e:
        logging.error("Error in run_agent: %s", e)
        response += f"\n[Error in run_agent: {e}]"

    # Extract actions from the response.
    response_actions = []
    for line in response.strip().split("\n"):
        if line.startswith("action:"):
            # Slice off the prefix instead of using str.replace(), which would
            # also rewrite "action: " occurrences later in the line.
            response_actions.append(line[len("action:"):].strip())
    
    return response, response_actions
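
# Illustrative behavior of the action extraction above: a model reply such as
#   "I'll update the page.\naction: edit index.html\naction: run tests"
# yields response_actions == ["edit index.html", "run tests"].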

def generate(prompt: str, history: List[str], temperature: float) -> Generator[str, None, None]:
    """
    Generate text from the AI model using the formatted prompt.
    
    Args:
        prompt (str): The input prompt.
        history (List[str]): Recent conversation history.
        temperature (float): Sampling temperature.
    
    Yields:
        str: Incremental output from the text-generation stream.
    """
    seed = random.randint(1, 2**32 - 1) if config.default_seed is None else config.default_seed
    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": config.max_new_tokens,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "do_sample": True,
        "seed": seed,
    }
    formatted_prompt = format_prompt_var(prompt, history)
    
    try:
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False
        )
    except Exception as e:
        logging.error("Error during text_generation call: %s", e)
        yield f"[Error during text_generation call: {e}]"
        return

    output = ""
    iterations = 0
    for response in stream:
        iterations += 1
        try:
            output += response.token.text
        except AttributeError as ae:
            logging.error("Malformed response token: %s", ae)
            yield f"[Malformed response token: {ae}]"
            break
        yield output
        if iterations >= config.max_iterations:
            yield "\n[Response truncated due to length limitations]"
            break
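
# Illustrative streaming usage of generate() (example arguments):
#   for chunk in generate("Add a footer", [], temperature=0.7):
#       print(chunk, end="", flush=True)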

async def async_run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
    """
    Asynchronous wrapper to run the agent in a separate thread.
    
    Args:
        instruction (str): The instruction for the AI.
        history (List[str]): The conversation history.
    
    Returns:
        Tuple[str, List[str]]: The response and extracted actions.
    """
    # asyncio.to_thread (Python 3.9+) runs the blocking call in a worker
    # thread so the event loop stays responsive.
    return await asyncio.to_thread(run_agent, instruction, history)
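
# Illustrative call from synchronous code, e.g. a quick smoke test:
#   response, actions = asyncio.run(async_run_agent("Add a navbar", []))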

def clear_conversation() -> List[str]:
    """
    Clear the conversation history.
    
    Returns:
        List[str]: An empty conversation history.
    """
    return []

def update_chatbot_styles(history: List[Any]) -> Any:
    """
    Update the chatbot label to reflect the number of messages.

    Args:
        history (List[Any]): The current conversation history.

    Returns:
        Update object for the Gradio Chatbot.
    """
    # Chatbot history entries are (user, bot) pairs; Gradio may deliver them
    # as lists rather than tuples, so accept both.
    num_messages = sum(1 for item in history if isinstance(item, (list, tuple)))
    # gr.update() is the supported way to change component properties;
    # "num_messages" is not a Chatbot property, so surface the count in the
    # component label instead.
    return gr.update(label=f"Conversation ({num_messages} messages)")

def update_max_history(value: int) -> int:
    """
    Update the max_history in configuration.

    Args:
        value (int): New maximum history value.
    
    Returns:
        int: The updated max_history.
    """
    config.max_history = int(value)
    return config.max_history

def create_interface() -> gr.Blocks:
    """
    Create and return the Gradio interface for the chatbot application.
    
    Returns:
        gr.Blocks: The Gradio Blocks object representing the UI.
    """
    async def respond(
        message: str, chat_history: List[Tuple[str, str]]
    ) -> Tuple[str, List[Tuple[str, str]]]:
        # Flatten the (user, bot) pairs into the flat string history that
        # run_agent expects, run the agent off the event loop, and append
        # the new exchange. Returning "" clears the textbox.
        flat_history = [text for pair in chat_history for text in pair]
        response, _actions = await async_run_agent(message, flat_history)
        chat_history.append((message, response))
        return "", chat_history

    with gr.Blocks() as block:
        gr.Markdown("## Expert Web Developer Assistant")
        with gr.Tab("Conversation"):
            # Components must be created inside the Blocks context to render.
            chatbot = gr.Chatbot()
            txt = gr.Textbox(show_label=False, placeholder="Type something...")
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear")

            # Submitting text or clicking Send runs the agent asynchronously.
            txt.submit(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
            send_btn.click(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])

            # Clear the conversation history, then refresh the chatbot label.
            clear_btn.click(fn=clear_conversation, outputs=chatbot).then(
                update_chatbot_styles, chatbot, chatbot
            )

        with gr.Tab("Settings"):
            max_history_slider = gr.Slider(
                minimum=1, maximum=100, step=1,
                label="Max history", 
                value=config.max_history
            )
            max_history_slider.change(
                update_max_history, max_history_slider, max_history_slider
            )

    return block

if __name__ == "__main__":
    interface = create_interface()
    interface.launch()