Spaces:
Running
Running
import os | |
import gradio as gr | |
from sqlalchemy import text | |
from smolagents import tool, CodeAgent, HfApiModel | |
import spaces | |
import pandas as pd | |
from database import ( | |
engine, | |
create_dynamic_table, | |
clear_database, | |
insert_rows_into_table, | |
get_table_schema | |
) | |
def get_data_table():
    """
    Fetch every row of the first user table in the SQLite database.

    Returns:
        pd.DataFrame: The table contents. When no table exists, an empty
        DataFrame; when the table exists but has no rows, an empty DataFrame
        that still carries the table's column headers. On any failure, a
        single-column DataFrame named ``Error`` holding the error message.
    """
    try:
        # One connection is enough for both the table lookup and the read.
        with engine.connect() as con:
            tables = con.execute(text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )).fetchall()
            if not tables:
                return pd.DataFrame()
            # Use the first table found.
            table_name = tables[0][0]
            # Quote the identifier (doubling embedded quotes) so table names
            # containing spaces or SQL keywords still produce valid SQL.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            columns = list(result.keys())
            rows = result.fetchall()
        # With no rows this yields an empty frame that keeps the headers.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
def get_table_info():
    """
    Describe the first user table in the SQLite database.

    Returns:
        tuple: ``(table_name, column_names, column_info)`` where
        ``column_names`` is the list of column names in declaration order and
        ``column_info`` maps each column name to
        ``{'type': <declared type>, 'is_primary': <bool>}``.
        Returns ``(None, [], {})`` when no table exists or on any error
        (the error is printed).
    """
    try:
        with engine.connect() as con:
            tables = con.execute(text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )).fetchall()
            if not tables:
                return None, [], {}
            # Use the first table found.
            table_name = tables[0][0]
            # Quote the identifier so names with spaces/keywords still work.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
            columns = con.execute(text(f"PRAGMA table_info({quoted_name})")).fetchall()
        column_names = [col[1] for col in columns]
        column_info = {
            col[1]: {
                'type': col[2],
                # pk is 0 for non-key columns, otherwise the 1-based key position.
                'is_primary': bool(col[5])
            }
            for col in columns
        }
        return table_name, column_names, column_info
    except Exception as e:
        print(f"Error getting table info: {str(e)}")
        return None, [], {}
def _split_sql_statements(sql_content):
    """Split an SQL script into statements without breaking on quoted/commented ';'."""
    import sqlite3

    statements = []
    pending = []
    for piece in sql_content.split(";"):
        pending.append(piece)
        candidate = ";".join(pending)
        # complete_statement() is only True once the terminating ';' lies
        # outside string literals, comments and CREATE TRIGGER bodies, so
        # pieces split on an embedded ';' are re-joined until then.
        if sqlite3.complete_statement(candidate + ";"):
            if candidate.strip():
                statements.append(candidate.strip())
            pending = []
    leftover = ";".join(pending).strip()
    if leftover:
        statements.append(leftover)
    return statements


def process_sql_file(file_path):
    """
    Replace the database contents with the result of running an SQL script.

    Args:
        file_path: Path to a UTF-8 encoded ``.sql`` file.

    Returns:
        tuple[bool, str]: ``(True, success message)`` or
        ``(False, error message)``; exceptions are never propagated.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            sql_content = file.read()
        # Replace AUTO_INCREMENT with AUTOINCREMENT for SQLite compatibility
        sql_content = sql_content.replace('AUTO_INCREMENT', 'AUTOINCREMENT')
        statements = _split_sql_statements(sql_content)
        # Clear existing database
        clear_database()
        # Execute all statements in one transaction (rolled back on error).
        with engine.begin() as conn:
            for statement in statements:
                conn.execute(text(statement))
        return True, "SQL file successfully executed!"
    except Exception as e:
        return False, f"Error processing SQL file: {str(e)}"
def process_csv_file(file_path):
    """
    Load a CSV file into a brand-new database table.

    The existing database is cleared, a table is created from the
    DataFrame's columns, and every CSV row is inserted.

    Args:
        file_path: Path to the CSV file (the header row is required).

    Returns:
        tuple[bool, str]: ``(True, success message)`` or
        ``(False, error message)``; exceptions are never propagated.
    """
    try:
        frame = pd.read_csv(file_path)
        if frame.columns.empty:
            return False, "Error: File contains no columns"
        clear_database()
        new_table = create_dynamic_table(frame)
        # One dict per row, keyed by column name.
        insert_rows_into_table(frame.to_dict("records"), new_table)
    except Exception as exc:
        return False, f"Error processing CSV file: {str(exc)}"
    return True, "CSV file successfully loaded!"
def process_uploaded_file(file):
    """
    Dispatch an uploaded file to the SQL or CSV loader based on its extension.

    Args:
        file: Filesystem path of the upload, or ``None`` if nothing was uploaded.

    Returns:
        tuple[bool, str]: ``(success, message)`` from the matching loader, or
        ``(False, message)`` for a missing/unsupported file or any error.
    """
    try:
        if file is None:
            return False, "Please upload a file."
        _, extension = os.path.splitext(file)
        extension = extension.lower()  # extension match is case-insensitive
        if extension not in ('.sql', '.csv'):
            return False, "Error: Unsupported file type. Please upload either a .sql or .csv file."
        return process_sql_file(file) if extension == '.sql' else process_csv_file(file)
    except Exception as e:
        return False, f"Error processing file: {str(e)}"
def sql_engine(query: str) -> str:
    """
    Executes an SQL query against the database and renders the rows as text.

    Args:
        query: The SQL query string to execute on the database. Must be a valid SELECT query.

    Returns:
        str: "No results found." for an empty result, the bare value for a
        single 1x1 result, otherwise one comma-separated line per row; or
        "Error: <message>" if execution fails.
    """
    try:
        with engine.connect() as connection:
            fetched = connection.execute(text(query)).fetchall()
        if not fetched:
            return "No results found."
        is_scalar = len(fetched) == 1 and len(fetched[0]) == 1
        if is_scalar:
            return str(fetched[0][0])
        return "\n".join(", ".join(str(cell) for cell in row) for row in fetched)
    except Exception as e:
        return f"Error: {str(e)}"
# sql_engine is a plain function, but smolagents agents index their tools by
# ``tool.name`` and expect Tool objects. Wrap it with smolagents' ``tool``
# decorator here (it relies on sql_engine's type hints and "Args:" docstring).
# sql_engine itself stays a plain function, so query_sql can call it directly.
agent = CodeAgent(
    tools=[tool(sql_engine)],
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
)
def _build_schema_prompt(table_name, column_info):
    """Return the schema description and instructions prepended to every agent request."""
    return (
        f"The database has a table named '{table_name}' with the following columns:\n"
        + "\n".join([
            f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
            for col, info in column_info.items()
        ])
        + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
        "The table name is '" + table_name + "'.\n"
        "If column names contain spaces, they must be quoted.\n"
        "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )


def _format_if_numeric(text):
    """Return *text* with thousands separators if it is a finite number, else None."""
    import math

    try:
        return f"{int(text):,}"
    except (TypeError, ValueError):
        pass
    try:
        value = float(text)
    except (TypeError, ValueError):
        return None
    if not math.isfinite(value):
        # Don't turn text such as "NaN"/"Infinity" into "nan"/"inf".
        return None
    if value.is_integer():
        return f"{value:,.0f}"
    # Keep the fractional part (e.g. averages) instead of rounding it away.
    return f"{value:,}"


def _extract_sql(answer, table_name, column_names):
    """Pull one SQL statement out of the agent's text and patch common mistakes."""
    import re

    # Turn Markdown code fences into line breaks so they delimit the statement.
    cleaned = re.sub(r"```(?:sql)?", "\n", answer, flags=re.IGNORECASE)
    lines = cleaned.split("\n")
    start = next((i for i, line in enumerate(lines) if 'select' in line.lower()), None)
    if start is None:
        sql = cleaned
    else:
        # Keep multi-line SQL: collect lines from the first SELECT line until
        # a blank line (or fence) or a line that ends the statement with ';'.
        picked = []
        for line in lines[start:]:
            stripped = line.strip()
            if not stripped:
                break
            picked.append(stripped)
            if stripped.endswith(';'):
                break
        sql = " ".join(picked)
    sql = sql.strip().rstrip(';')

    # Replace placeholder/guessed table names with the real one, as whole
    # words only, so e.g. 'main' never rewrites 'domain' or 'main_data'.
    for wrong_name in ('table_name', 'customers', 'main'):
        pattern = re.compile(rf"\b{re.escape(wrong_name)}\b")
        if wrong_name == table_name or any(pattern.search(col) for col in column_names):
            continue
        sql = pattern.sub(lambda _m: table_name, sql)

    # Double-quote every not-yet-quoted occurrence of column names with spaces.
    for col in column_names:
        if ' ' in col:
            pattern = r'(?<!["`\[])' + re.escape(col) + r'(?!["`\]])'
            sql = re.sub(pattern, lambda _m, quoted=f'"{col}"': quoted, sql)
    return sql


def query_sql(user_query: str) -> str:
    """
    Answer a natural-language question about the uploaded table via the CodeAgent.

    The agent is asked for an SQL query. Its reply may instead already be the
    final answer (the value it passed to ``final_answer``, possibly not a
    str), so numeric replies are returned directly. Otherwise the reply is
    cleaned into SQL and run through ``sql_engine``. If that SQL is
    syntactically invalid, the agent's raw reply is shown instead.

    Args:
        user_query: The user's question in natural language.

    Returns:
        str: The (number-formatted) result, or an error message.
    """
    table_name, column_names, column_info = get_table_info()
    if not table_name:
        return "Error: No data table exists. Please upload a file first."

    schema_info = _build_schema_prompt(table_name, column_info)
    try:
        answer = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    except Exception as e:
        return f"Error running agent: {str(e)}"

    # A non-str reply (e.g. an int count) is a real answer, not an error.
    answer = str(answer).strip()
    if not answer:
        return "Error: Invalid query generated"

    numeric_answer = _format_if_numeric(answer)
    if numeric_answer is not None:
        return numeric_answer

    generated_sql = _extract_sql(answer, table_name, column_names)
    # sql_engine never raises; failures come back as "Error: ..." strings.
    result = sql_engine(generated_sql)
    if result.startswith("Error: (sqlite3.OperationalError) near"):
        # Not valid SQL: most likely the agent answered in prose.
        return answer
    numeric_result = _format_if_numeric(result)
    return result if numeric_result is None else numeric_result
# Create the Gradio interface


def _format_schema_markdown(column_info):
    """Render the column metadata from get_table_info() as a Markdown code block."""
    schema = "\n".join(
        f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
        for col, info in column_info.items()
    )
    return f"### Current Schema:\n```\n{schema}\n```"


with gr.Blocks() as demo:
    with gr.Group() as upload_group:
        gr.Markdown("""
# CSVAgent
Upload your data file to begin.
### Supported File Types:
- CSV (.csv): CSV file with headers that will be automatically converted to a table
### CSV Requirements:
- Must include headers
- First column will be used as the primary key
- Column types will be automatically detected
- Sample CSV Files: https://github.com/datablist/sample-csv-files
### Based on ZennyKenny's SqlAgent
### SQL to CSV File Conversion
https://tableconvert.com/sql-to-csv
- Will work on the handling of SQL files soon.
### Try it out! Upload a CSV file and then ask a question about the data!
- There is issues with the UI displaying the answer correctly, some questions such as "How many Customers are located in Korea?"
The right answer will appear in the logs, but throws an error on the "Results" section.
""")
        file_input = gr.File(
            label="Upload Data File",
            file_types=[".csv", ".sql"],
            type="filepath"
        )
        status = gr.Textbox(label="Status", interactive=False)

    with gr.Group(visible=False) as query_group:
        with gr.Row():
            with gr.Column(scale=1):
                user_input = gr.Textbox(label="Ask a question about the data")
                query_output = gr.Textbox(label="Result")
            with gr.Column(scale=2):
                gr.Markdown("### Current Data")
                data_table = gr.Dataframe(
                    value=None,
                    label="Data Table",
                    interactive=False
                )
        schema_display = gr.Markdown(value="Loading schema...")
        refresh_btn = gr.Button("Refresh Data")

    def handle_upload(file_obj):
        """Load the uploaded file; on success show the query panel, else keep the upload panel."""
        if file_obj is None:
            return (
                "Please upload a file.",
                None,
                "No schema available",
                gr.update(visible=True),
                gr.update(visible=False)
            )
        success, message = process_uploaded_file(file_obj)
        if success:
            _, _, column_info = get_table_info()
            return (
                message,
                get_data_table(),
                _format_schema_markdown(column_info),
                gr.update(visible=False),
                gr.update(visible=True)
            )
        return (
            message,
            None,
            "No schema available",
            gr.update(visible=True),
            gr.update(visible=False)
        )

    def refresh_data():
        """Re-read the current table contents and schema from the database."""
        _, _, column_info = get_table_info()
        return get_data_table(), _format_schema_markdown(column_info)

    # Event handlers
    file_input.upload(
        fn=handle_upload,
        inputs=file_input,
        outputs=[
            status,
            data_table,
            schema_display,
            upload_group,
            query_group
        ]
    )
    # Run the (slow, paid) LLM agent only when the user presses Enter;
    # `.change` fired it on every keystroke.
    user_input.submit(
        fn=query_sql,
        inputs=user_input,
        outputs=query_output
    )
    refresh_btn.click(
        fn=refresh_data,
        outputs=[data_table, schema_display]
    )
if __name__ == "__main__":
    # Listen on all network interfaces, on port 7860 (Gradio's default port).
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)