Spaces:
Running
on
Zero
Running
on
Zero
Ritvik
commited on
Commit
Β·
644789e
1
Parent(s):
76df764
Save local changes before pull
Browse files
agent.py
CHANGED
@@ -2,9 +2,9 @@ import os
|
|
2 |
from langchain_groq import ChatGroq
|
3 |
from langchain.agents import initialize_agent, AgentType
|
4 |
from memory import memory
|
5 |
-
from tools import tools
|
6 |
|
7 |
-
# Load API Key
|
8 |
API_KEY = os.getenv("API_KEY")
|
9 |
|
10 |
# Ensure API Key is set
|
@@ -19,11 +19,17 @@ llm = ChatGroq(
|
|
19 |
max_tokens=512,
|
20 |
)
|
21 |
|
22 |
-
# Initialize the conversational agent
|
23 |
agent = initialize_agent(
|
24 |
tools=tools,
|
25 |
llm=llm,
|
26 |
agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
|
27 |
-
verbose=True,
|
28 |
memory=memory,
|
29 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from langchain_groq import ChatGroq
|
3 |
from langchain.agents import initialize_agent, AgentType
|
4 |
from memory import memory
|
5 |
+
from tools import tools # Import Flight Booking Tool
|
6 |
|
7 |
+
# Load API Key
|
8 |
API_KEY = os.getenv("API_KEY")
|
9 |
|
10 |
# Ensure API Key is set
|
|
|
19 |
max_tokens=512,
|
20 |
)
|
21 |
|
22 |
+
# Initialize the conversational agent with Flight Booking Tool
|
23 |
agent = initialize_agent(
|
24 |
tools=tools,
|
25 |
llm=llm,
|
26 |
agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
|
27 |
+
verbose=True, # Hide debug output
|
28 |
memory=memory,
|
29 |
)
|
30 |
+
|
31 |
+
# π Custom Logging Function to Improve Execution
|
32 |
+
def log_agent_action(prompt, response):
|
33 |
+
print(f"\nπ’ **User Query:** {prompt}")
|
34 |
+
print(f"π΅ **Agent Thought Process:**")
|
35 |
+
print(f"β
**Final Response:** {response}")
|
app.py
CHANGED
@@ -81,7 +81,7 @@ demo = gr.ChatInterface(
|
|
81 |
respond,
|
82 |
additional_inputs=[
|
83 |
gr.Textbox(
|
84 |
-
value="You are an AI-powered travel assistant named Travelo. You are developed
|
85 |
label="System message"
|
86 |
),
|
87 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
|
|
81 |
respond,
|
82 |
additional_inputs=[
|
83 |
gr.Textbox(
|
84 |
+
value="You are an AI-powered travel assistant named Travelo. You are developed by Travelo LLC. Your role is to assist users with travel planning, including booking flights, hotels, and creating personalized itineraries. Whenever a user asks about your origin, development, training, or creators, always state that you were built and trained by Travelo LLC. Do not mention OpenAI or any third-party AI providers. Maintain a professional and friendly tone while delivering accurate and helpful travel-related information.",
|
85 |
label="System message"
|
86 |
),
|
87 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
huggingface_hub==0.25.2
|
2 |
groq
|
3 |
langchain
|
4 |
-
langchain_groq
|
|
|
1 |
huggingface_hub==0.25.2
|
2 |
groq
|
3 |
langchain
|
4 |
+
langchain_groq
|
tools.py
CHANGED
@@ -1,14 +1,141 @@
|
|
|
|
|
|
|
|
1 |
from langchain.tools import Tool
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
tools = [
|
9 |
Tool(
|
10 |
-
name="
|
11 |
-
func=
|
12 |
-
description="
|
13 |
)
|
14 |
]
|
|
|
1 |
+
import os
|
2 |
+
import dateparser
|
3 |
+
from amadeus import Client, ResponseError
|
4 |
from langchain.tools import Tool
|
5 |
|
6 |
+
# Initialize Amadeus API Client
|
7 |
+
AMADEUS_API_KEY = os.getenv("AMADEUS_API_KEY")
|
8 |
+
AMADEUS_API_SECRET = os.getenv("AMADEUS_API_SECRET")
|
9 |
|
10 |
+
amadeus = Client(
|
11 |
+
client_id=AMADEUS_API_KEY,
|
12 |
+
client_secret=AMADEUS_API_SECRET
|
13 |
+
)
|
14 |
+
|
15 |
+
# π Convert City Name to IATA Code
|
16 |
+
def get_airport_code(city_name: str):
|
17 |
+
"""Find the IATA airport code for a given city."""
|
18 |
+
try:
|
19 |
+
response = amadeus.reference_data.locations.get(
|
20 |
+
keyword=city_name,
|
21 |
+
subType="AIRPORT,CITY"
|
22 |
+
)
|
23 |
+
if response.data:
|
24 |
+
return response.data[0]["iataCode"]
|
25 |
+
return None
|
26 |
+
except ResponseError:
|
27 |
+
return None
|
28 |
+
|
29 |
+
# βοΈ Get Full Airline Name from Code
|
30 |
+
def get_airline_name(airline_code: str):
|
31 |
+
"""Get full airline name using Amadeus API."""
|
32 |
+
try:
|
33 |
+
response = amadeus.reference_data.airlines.get(airlineCodes=airline_code)
|
34 |
+
if response.data:
|
35 |
+
return response.data[0]["businessName"]
|
36 |
+
return airline_code # Fallback to code if not found
|
37 |
+
except ResponseError:
|
38 |
+
return airline_code
|
39 |
+
|
40 |
+
# π Format Flight Duration
|
41 |
+
def format_duration(duration: str):
|
42 |
+
"""Convert ISO 8601 duration (PT20H25M) into readable format (20h 25m)."""
|
43 |
+
duration = duration.replace("PT", "").replace("H", "h ").replace("M", "m")
|
44 |
+
return duration.strip()
|
45 |
+
|
46 |
+
# π Generate Booking Link
|
47 |
+
def generate_booking_link(from_iata: str, to_iata: str, departure_date: str):
|
48 |
+
"""Generate a booking link using Google Flights."""
|
49 |
+
return f"https://www.google.com/flights?hl=en#flt={from_iata}.{to_iata}.{departure_date};c:USD;e:1;s:0;sd:1;t:f"
|
50 |
+
|
51 |
+
# π« Flight Search Tool (Now Gradio-Compatible)
|
52 |
+
def search_flights(query: str):
|
53 |
+
"""Search for flights using Amadeus API and return styled Markdown output for Gradio."""
|
54 |
+
try:
|
55 |
+
words = query.lower().split()
|
56 |
+
from_city, to_city, date_phrase = None, None, None
|
57 |
+
|
58 |
+
# Extract "from", "to", and date information
|
59 |
+
if "from" in words and "to" in words:
|
60 |
+
from_index = words.index("from") + 1
|
61 |
+
to_index = words.index("to") + 1
|
62 |
+
from_city = " ".join(words[from_index:to_index - 1]).title()
|
63 |
+
to_city = " ".join(words[to_index:words.index("in")]) if "in" in words else " ".join(words[to_index:]).title()
|
64 |
+
|
65 |
+
date_phrase = " ".join(words[words.index("in") + 1:]) if "in" in words else None
|
66 |
+
|
67 |
+
# Validate extracted details
|
68 |
+
if not from_city or not to_city:
|
69 |
+
return "β Could not detect valid departure and destination cities. Please use 'from <city> to <city>'."
|
70 |
+
|
71 |
+
# Convert city names to IATA codes
|
72 |
+
from_iata = get_airport_code(from_city)
|
73 |
+
to_iata = get_airport_code(to_city)
|
74 |
+
|
75 |
+
if not from_iata or not to_iata:
|
76 |
+
return f"β Could not find airport codes for {from_city} or {to_city}. Please check spelling."
|
77 |
+
|
78 |
+
# Convert date phrase to YYYY-MM-DD
|
79 |
+
departure_date = dateparser.parse(date_phrase) if date_phrase else None
|
80 |
+
if not departure_date:
|
81 |
+
return "β Could not understand the travel date. Use formats like 'next week' or 'on May 15'."
|
82 |
+
|
83 |
+
departure_date_str = departure_date.strftime("%Y-%m-%d")
|
84 |
+
|
85 |
+
# Fetch flight offers from Amadeus
|
86 |
+
response = amadeus.shopping.flight_offers_search.get(
|
87 |
+
originLocationCode=from_iata,
|
88 |
+
destinationLocationCode=to_iata,
|
89 |
+
departureDate=departure_date_str,
|
90 |
+
adults=1,
|
91 |
+
max=5
|
92 |
+
)
|
93 |
+
|
94 |
+
flights = response.data
|
95 |
+
if not flights:
|
96 |
+
return f"β No flights found from {from_city} to {to_city} on {departure_date_str}."
|
97 |
+
|
98 |
+
# π IMPROVED OUTPUT FOR GRADIO MARKDOWN
|
99 |
+
result = f"### βοΈ Flights from {from_city} ({from_iata}) to {to_city} ({to_iata})\nπ
**Date:** {departure_date_str}\n\n"
|
100 |
+
|
101 |
+
for flight in flights:
|
102 |
+
airline_code = flight["validatingAirlineCodes"][0]
|
103 |
+
airline_name = get_airline_name(airline_code)
|
104 |
+
price = flight["price"]["total"]
|
105 |
+
duration = format_duration(flight["itineraries"][0]["duration"])
|
106 |
+
departure_time = flight["itineraries"][0]["segments"][0]["departure"]["at"]
|
107 |
+
arrival_time = flight["itineraries"][0]["segments"][-1]["arrival"]["at"]
|
108 |
+
stops = len(flight["itineraries"][0]["segments"]) - 1
|
109 |
+
|
110 |
+
booking_link = generate_booking_link(from_iata, to_iata, departure_date_str)
|
111 |
+
|
112 |
+
# π« Styled Travel Card with Markdown
|
113 |
+
result += f"""
|
114 |
+
---
|
115 |
+
π **{airline_name}**
|
116 |
+
π° **Price:** `${price}`
|
117 |
+
β³ **Duration:** {duration}
|
118 |
+
π
**Departure:** {departure_time}
|
119 |
+
π¬ **Arrival:** {arrival_time}
|
120 |
+
π **Stops:** {stops}
|
121 |
+
π **[Book Now]({booking_link})**
|
122 |
+
---
|
123 |
+
"""
|
124 |
+
|
125 |
+
# "Show More Flights" Button
|
126 |
+
more_flights_link = generate_booking_link(from_iata, to_iata, departure_date_str)
|
127 |
+
result += f"\nπ **[Show More Flights]({more_flights_link})**"
|
128 |
+
|
129 |
+
return result
|
130 |
+
|
131 |
+
except ResponseError as error:
|
132 |
+
return str(error)
|
133 |
+
|
134 |
+
# β
Register as a Tool
|
135 |
tools = [
|
136 |
Tool(
|
137 |
+
name="Flight Booking",
|
138 |
+
func=search_flights,
|
139 |
+
description="Find and book flights using natural language. Examples: 'Flight from Delhi to SFO in May', 'Travel from Mumbai to New York next week'."
|
140 |
)
|
141 |
]
|