Travel_AI_V1 / app.py
Ritvik
Updated app.py
ca9dac5
raw
history blame
2.26 kB
import gradio as gr
import os
from langchain_groq import ChatGroq # Using Groq's API
from langchain.memory import ConversationBufferMemory
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from langchain.agents import initialize_agent, AgentType
from langchain.tools import Tool
# Groq credentials are read from the API_KEY environment variable; the value
# is None when the variable is not set.
API_KEY = os.environ.get("API_KEY")

# Chat model shared by the whole app: the Mixtral 8x7B model served by Groq,
# with fixed sampling settings chosen at import time.
llm = ChatGroq(
    model_name="mixtral-8x7b-32768",
    groq_api_key=API_KEY,
    max_tokens=512,
    temperature=0.7,
)
# Memory for conversation history, exposed to the agent prompt under the
# "chat_history" key.
# The agent built from this memory uses the text-prompt conversational ReAct
# agent type, whose template interpolates chat_history as a plain string.
# return_messages=False makes the memory render the transcript as readable
# "Human: ... / AI: ..." lines instead of handing the prompt a list of message
# objects that would be stringified as their Python repr.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=False)
# Define useful tools
def search_tool(query: str) -> str:
    """Return a canned placeholder result for *query*.

    No real search is performed: the query is simply echoed back inside a
    fixed sample response. Swap the body for a real API call when one is
    available.
    """
    template = "Searching for: {query}... [Sample Response]"
    return template.format(query=query)
# Tools the agent is allowed to call. The name and description are the text
# the agent is given to decide when to use the tool; func is what runs.
tools = [
    Tool(
        description="Searches for information based on user queries.",
        func=search_tool,
        name="Search Tool",
    ),
]
# Initialize the agent: a conversational ReAct agent that uses the shared LLM,
# the tool list above and the module-level conversation memory.
agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=True,  # log the agent's intermediate reasoning steps
    memory=memory,
    # If the model's reply does not follow the expected ReAct format, feed the
    # parsing error back to the model and let it retry instead of raising and
    # failing the whole chat turn.
    handle_parsing_errors=True,
)
# Define response function
def _history_to_messages(history):
    """Convert a Gradio chat transcript into LangChain message objects.

    Accepts both transcript layouts Gradio can supply: a list of
    ``(user_text, bot_text)`` pairs, or a list of ``{"role": ..., "content": ...}``
    dicts. Empty or non-text entries are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            user_input, bot_response = turn
            pairs = [("user", user_input), ("assistant", bot_response)]
        for role, content in pairs:
            # Only plain, non-empty text can become a chat message here.
            if not content or not isinstance(content, str):
                continue
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
    return messages


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer one chat turn through the agent.

    Args:
        message: The user's new message.
        history: Prior turns of this conversation as supplied by Gradio.
        system_message: System prompt text placed at the start of the memory.
        max_tokens: Accepted because the UI passes it; not used by this function.
        temperature: Accepted because the UI passes it; not used by this function.
        top_p: Accepted because the UI passes it; not used by this function.

    Returns:
        The agent's reply text.
    """
    # `memory` is module-level state and `history` already holds the whole
    # transcript, so rebuild from scratch each turn. Appending without clearing
    # would duplicate the system prompt and every earlier turn on each call.
    # NOTE(review): memory is shared by all sessions of this process — confirm
    # whether concurrent users are expected.
    memory.clear()

    if system_message:
        memory.chat_memory.add_message(SystemMessage(content=system_message))
    for past_message in _history_to_messages(history):
        memory.chat_memory.add_message(past_message)

    # The new message is not added by hand: the agent receives it as input and
    # saves the exchange to memory itself once the run finishes, so adding it
    # here too would store it twice.
    return agent.run(message)
# Gradio chat UI. The extra controls below are handed to respond() after the
# message and history, in the order listed here, so their order must match
# respond's (system_message, max_tokens, temperature, top_p) parameters.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are a helpful Travel AI assistant.Your name is Travelo",
        ),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()