Ritvik commited on
Commit
ca9dac5
·
1 Parent(s): 0a34dbc

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -26
app.py CHANGED
@@ -1,43 +1,69 @@
1
  import gradio as gr
2
- from groq import Groq
3
  import os
 
 
 
 
 
4
 
5
- # Set your Groq API key
6
- API_KEY = os.getenv("API_KEY")
7
- client = Groq(api_key=API_KEY) # Pass API key explicitly
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def respond(message, history, system_message, max_tokens, temperature, top_p):
10
- messages = [{"role": "system", "content": system_message}]
11
 
12
  for user_input, bot_response in history:
13
  if user_input:
14
- messages.append({"role": "user", "content": user_input})
15
  if bot_response:
16
- messages.append({"role": "assistant", "content": bot_response})
17
-
18
- messages.append({"role": "user", "content": message})
19
-
20
- completion = client.chat.completions.create(
21
- model="mixtral-8x7b-32768",
22
- messages=messages,
23
- temperature=temperature,
24
- max_completion_tokens=max_tokens,
25
- top_p=top_p,
26
- stream=True,
27
- stop=None,
28
- )
29
 
30
- response = ""
31
- for chunk in completion:
32
- token = chunk.choices[0].delta.content or ""
33
- response += token
34
- yield response
35
 
36
  # Gradio Interface
37
  demo = gr.ChatInterface(
38
  respond,
39
  additional_inputs=[
40
- gr.Textbox(value="You are a friendly AI assistant.", label="System message"),
41
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
42
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
43
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
@@ -45,4 +71,4 @@ demo = gr.ChatInterface(
45
  )
46
 
47
  if __name__ == "__main__":
48
- demo.launch()
 
1
  import gradio as gr
 
2
  import os
3
+ from langchain_groq import ChatGroq # Using Groq's API
4
+ from langchain.memory import ConversationBufferMemory
5
+ from langchain.schema import SystemMessage, HumanMessage, AIMessage
6
+ from langchain.agents import initialize_agent, AgentType
7
+ from langchain.tools import Tool
8
 
9
+ # Set API Key for Groq
10
+ API_KEY = os.getenv("API_KEY") # Ensure API Key is set in the environment
 
11
 
12
+ # Initialize the LLM (Groq's Mixtral)
13
+ llm = ChatGroq(
14
+ groq_api_key=API_KEY,
15
+ model_name="mixtral-8x7b-32768",
16
+ temperature=0.7,
17
+ max_tokens=512,
18
+ )
19
+
20
+ # Memory for conversation history
21
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
22
+
23
+ # Define useful tools
24
+ def search_tool(query: str) -> str:
25
+ """A simple search function (can be connected to real APIs)."""
26
+ return f"Searching for: {query}... [Sample Response]"
27
+
28
+ tools = [
29
+ Tool(
30
+ name="Search Tool",
31
+ func=search_tool,
32
+ description="Searches for information based on user queries."
33
+ )
34
+ ]
35
+
36
+ # Initialize the agent
37
+ agent = initialize_agent(
38
+ tools=tools,
39
+ llm=llm,
40
+ agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
41
+ verbose=True,
42
+ memory=memory,
43
+ )
44
+
45
+ # Define response function
46
  def respond(message, history, system_message, max_tokens, temperature, top_p):
47
+ memory.chat_memory.add_message(SystemMessage(content=system_message))
48
 
49
  for user_input, bot_response in history:
50
  if user_input:
51
+ memory.chat_memory.add_message(HumanMessage(content=user_input))
52
  if bot_response:
53
+ memory.chat_memory.add_message(AIMessage(content=bot_response))
54
+
55
+ memory.chat_memory.add_message(HumanMessage(content=message))
56
+
57
+ response = agent.run(message)
58
+
59
+ return response
 
 
 
 
 
 
60
 
 
 
 
 
 
61
 
62
  # Gradio Interface
63
  demo = gr.ChatInterface(
64
  respond,
65
  additional_inputs=[
66
+ gr.Textbox(value="You are a helpful Travel AI assistant.Your name is Travelo", label="System message"),
67
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
68
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
69
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
71
  )
72
 
73
  if __name__ == "__main__":
74
+ demo.launch()