Quazim0t0 commited on
Commit
a5b666f
·
verified ·
1 Parent(s): ef52b84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -30
app.py CHANGED
@@ -22,18 +22,24 @@ def get_data_table():
22
  tables = con.execute(text(
23
  "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
24
  )).fetchall()
 
25
  if not tables:
26
  return pd.DataFrame()
 
27
  # Use the first table found
28
  table_name = tables[0][0]
 
29
  with engine.connect() as con:
30
  result = con.execute(text(f"SELECT * FROM {table_name}"))
31
  rows = result.fetchall()
 
32
  if not rows:
33
  return pd.DataFrame()
 
34
  columns = result.keys()
35
  df = pd.DataFrame(rows, columns=columns)
36
  return df
 
37
  except Exception as e:
38
  return pd.DataFrame({"Error": [str(e)]})
39
 
@@ -49,13 +55,17 @@ def get_table_info():
49
  tables = con.execute(text(
50
  "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
51
  )).fetchall()
 
52
  if not tables:
53
  return None, [], {}
 
54
  # Use the first table found
55
  table_name = tables[0][0]
 
56
  # Get column information
57
  with engine.connect() as con:
58
  columns = con.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
 
59
  # Extract column names and types
60
  column_names = [col[1] for col in columns]
61
  column_info = {
@@ -65,7 +75,9 @@ def get_table_info():
65
  }
66
  for col in columns
67
  }
 
68
  return table_name, column_names, column_info
 
69
  except Exception as e:
70
  print(f"Error getting table info: {str(e)}")
71
  return None, [], {}
@@ -78,18 +90,24 @@ def process_sql_file(file_path):
78
  # Read the SQL file
79
  with open(file_path, 'r') as file:
80
  sql_content = file.read()
 
81
  # Replace AUTO_INCREMENT with AUTOINCREMENT for SQLite compatibility
82
  sql_content = sql_content.replace('AUTO_INCREMENT', 'AUTOINCREMENT')
 
83
  # Split into individual statements
84
  statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()]
 
85
  # Clear existing database
86
  clear_database()
 
87
  # Execute each statement
88
  with engine.begin() as conn:
89
  for statement in statements:
90
  if statement.strip():
91
  conn.execute(text(statement))
 
92
  return True, "SQL file successfully executed!"
 
93
  except Exception as e:
94
  return False, f"Error processing SQL file: {str(e)}"
95
 
@@ -100,15 +118,20 @@ def process_csv_file(file_path):
100
  try:
101
  # Read the CSV file
102
  df = pd.read_csv(file_path)
 
103
  if len(df.columns) == 0:
104
  return False, "Error: File contains no columns"
 
105
  # Clear existing database and create new table
106
  clear_database()
107
  table = create_dynamic_table(df)
 
108
  # Convert DataFrame to list of dictionaries and insert
109
  records = df.to_dict('records')
110
  insert_rows_into_table(records, table)
 
111
  return True, "CSV file successfully loaded!"
 
112
  except Exception as e:
113
  return False, f"Error processing CSV file: {str(e)}"
114
 
@@ -119,14 +142,17 @@ def process_uploaded_file(file):
119
  try:
120
  if file is None:
121
  return False, "Please upload a file."
 
122
  # Get file extension
123
  file_ext = os.path.splitext(file)[1].lower()
 
124
  if file_ext == '.sql':
125
  return process_sql_file(file)
126
  elif file_ext == '.csv':
127
  return process_csv_file(file)
128
  else:
129
  return False, "Error: Unsupported file type. Please upload either a .sql or .csv file."
 
130
  except Exception as e:
131
  return False, f"Error processing file: {str(e)}"
132
 
@@ -134,19 +160,25 @@ def process_uploaded_file(file):
134
  def sql_engine(query: str) -> str:
135
  """
136
  Executes an SQL query and returns formatted results.
 
137
  Args:
138
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
 
139
  Returns:
140
  str: The formatted query results as a string.
141
  """
142
  try:
143
  with engine.connect() as con:
144
  rows = con.execute(text(query)).fetchall()
 
145
  if not rows:
146
  return "No results found."
147
- if len(rows) == 1 and len(rows[0]) == 1: # Fixed the comparison operators here
 
148
  return str(rows[0][0])
 
149
  return "\n".join([", ".join(map(str, row)) for row in rows])
 
150
  except Exception as e:
151
  return f"Error: {str(e)}"
152
 
@@ -155,15 +187,15 @@ agent = CodeAgent(
155
  model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
156
  )
157
 
158
- def query_sql(user_query: str, show_full: bool) -> tuple:
159
  """
160
  Converts natural language input to an SQL query using CodeAgent.
161
  """
162
  table_name, column_names, column_info = get_table_info()
163
 
164
  if not table_name:
165
- return "Error: No data table exists. Please upload a file first.", ""
166
-
167
  schema_info = (
168
  f"The database has a table named '{table_name}' with the following columns:\n"
169
  + "\n".join([
@@ -177,16 +209,15 @@ def query_sql(user_query: str, show_full: bool) -> tuple:
177
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
178
  )
179
 
180
- # Get full response from the agent
181
- full_response = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
182
-
183
- if not isinstance(full_response, str):
184
- return "Error: Invalid query generated", ""
185
 
186
  # Clean up the SQL
187
- generated_sql = full_response
188
  if generated_sql.isnumeric(): # If the agent returned just a number
189
- return generated_sql, full_response
190
 
191
  # Extract just the SQL query if there's additional text
192
  sql_lines = [line for line in generated_sql.split('\n') if 'select' in line.lower()]
@@ -200,7 +231,7 @@ def query_sql(user_query: str, show_full: bool) -> tuple:
200
  for wrong_name in ['table_name', 'customers', 'main']:
201
  if wrong_name in generated_sql:
202
  generated_sql = generated_sql.replace(wrong_name, table_name)
203
-
204
  # Add quotes around column names that need them
205
  for col in column_names:
206
  if ' ' in col: # If column name contains spaces
@@ -210,40 +241,48 @@ def query_sql(user_query: str, show_full: bool) -> tuple:
210
  try:
211
  # Execute the query
212
  result = sql_engine(generated_sql)
 
213
  # Try to format as number if possible
214
  try:
215
  float_result = float(result)
216
- short_response = f"{float_result:,.0f}" # Format with commas, no decimals
217
  except ValueError:
218
- short_response = result
219
- return short_response, full_response
220
  except Exception as e:
221
  if str(e).startswith("(sqlite3.OperationalError) near"):
222
  # If it's a SQL syntax error, return the raw result
223
- return generated_sql, full_response
224
- return f"Error executing query: {str(e)}", full_response
225
 
226
  # Create the Gradio interface
227
  with gr.Blocks() as demo:
228
  with gr.Group() as upload_group:
229
  gr.Markdown("""
230
  # CSVAgent
 
231
  Upload your data file to begin.
 
232
  ### Supported File Types:
233
  - CSV (.csv): CSV file with headers that will be automatically converted to a table
 
234
  ### CSV Requirements:
235
  - Must include headers
236
  - First column will be used as the primary key
237
  - Column types will be automatically detected
238
  - Sample CSV Files: https://github.com/datablist/sample-csv-files
239
  ### Based on ZennyKenny's SqlAgent
 
240
  ### SQL to CSV File Conversion
241
  https://tableconvert.com/sql-to-csv
242
  - Will work on the handling of SQL files soon.
 
 
243
  ### Try it out! Upload a CSV file and then ask a question about the data!
244
- - There is issues with the UI displaying the answer correctly, some questions such as "How many Customers are located in Korea?"
245
  The right answer will appear in the logs, but throws an error on the "Results" section.
246
  """)
 
247
  file_input = gr.File(
248
  label="Upload Data File",
249
  file_types=[".csv", ".sql"],
@@ -256,8 +295,7 @@ with gr.Blocks() as demo:
256
  with gr.Column(scale=1):
257
  user_input = gr.Textbox(label="Ask a question about the data")
258
  query_output = gr.Textbox(label="Result")
259
- full_response_switch = gr.Switch(label="Show Full Response", value=False)
260
- full_response_output = gr.Textbox(label="Full Response", visible=False)
261
  with gr.Column(scale=2):
262
  gr.Markdown("### Current Data")
263
  data_table = gr.Dataframe(
@@ -265,6 +303,7 @@ with gr.Blocks() as demo:
265
  label="Data Table",
266
  interactive=False
267
  )
 
268
  schema_display = gr.Markdown(value="Loading schema...")
269
  refresh_btn = gr.Button("Refresh Data")
270
 
@@ -277,10 +316,11 @@ with gr.Blocks() as demo:
277
  gr.update(visible=True),
278
  gr.update(visible=False)
279
  )
 
280
  success, message = process_uploaded_file(file_obj)
281
  if success:
282
  df = get_data_table()
283
- _,_ , column_info = get_table_info()
284
  schema = "\n".join([
285
  f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
286
  for col, info in column_info.items()
@@ -302,7 +342,7 @@ with gr.Blocks() as demo:
302
 
303
  def refresh_data():
304
  df = get_data_table()
305
- _,_ , column_info = get_table_info()
306
  schema = "\n".join([
307
  f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
308
  for col, info in column_info.items()
@@ -324,14 +364,8 @@ with gr.Blocks() as demo:
324
 
325
  user_input.change(
326
  fn=query_sql,
327
- inputs=[user_input, full_response_switch],
328
- outputs=[query_output, full_response_output]
329
- )
330
-
331
- full_response_switch.change(
332
- fn=lambda x: gr.update(visible=x),
333
- inputs=full_response_switch,
334
- outputs=full_response_output
335
  )
336
 
337
  refresh_btn.click(
 
22
  tables = con.execute(text(
23
  "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
24
  )).fetchall()
25
+
26
  if not tables:
27
  return pd.DataFrame()
28
+
29
  # Use the first table found
30
  table_name = tables[0][0]
31
+
32
  with engine.connect() as con:
33
  result = con.execute(text(f"SELECT * FROM {table_name}"))
34
  rows = result.fetchall()
35
+
36
  if not rows:
37
  return pd.DataFrame()
38
+
39
  columns = result.keys()
40
  df = pd.DataFrame(rows, columns=columns)
41
  return df
42
+
43
  except Exception as e:
44
  return pd.DataFrame({"Error": [str(e)]})
45
 
 
55
  tables = con.execute(text(
56
  "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
57
  )).fetchall()
58
+
59
  if not tables:
60
  return None, [], {}
61
+
62
  # Use the first table found
63
  table_name = tables[0][0]
64
+
65
  # Get column information
66
  with engine.connect() as con:
67
  columns = con.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
68
+
69
  # Extract column names and types
70
  column_names = [col[1] for col in columns]
71
  column_info = {
 
75
  }
76
  for col in columns
77
  }
78
+
79
  return table_name, column_names, column_info
80
+
81
  except Exception as e:
82
  print(f"Error getting table info: {str(e)}")
83
  return None, [], {}
 
90
  # Read the SQL file
91
  with open(file_path, 'r') as file:
92
  sql_content = file.read()
93
+
94
  # Replace AUTO_INCREMENT with AUTOINCREMENT for SQLite compatibility
95
  sql_content = sql_content.replace('AUTO_INCREMENT', 'AUTOINCREMENT')
96
+
97
  # Split into individual statements
98
  statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()]
99
+
100
  # Clear existing database
101
  clear_database()
102
+
103
  # Execute each statement
104
  with engine.begin() as conn:
105
  for statement in statements:
106
  if statement.strip():
107
  conn.execute(text(statement))
108
+
109
  return True, "SQL file successfully executed!"
110
+
111
  except Exception as e:
112
  return False, f"Error processing SQL file: {str(e)}"
113
 
 
118
  try:
119
  # Read the CSV file
120
  df = pd.read_csv(file_path)
121
+
122
  if len(df.columns) == 0:
123
  return False, "Error: File contains no columns"
124
+
125
  # Clear existing database and create new table
126
  clear_database()
127
  table = create_dynamic_table(df)
128
+
129
  # Convert DataFrame to list of dictionaries and insert
130
  records = df.to_dict('records')
131
  insert_rows_into_table(records, table)
132
+
133
  return True, "CSV file successfully loaded!"
134
+
135
  except Exception as e:
136
  return False, f"Error processing CSV file: {str(e)}"
137
 
 
142
  try:
143
  if file is None:
144
  return False, "Please upload a file."
145
+
146
  # Get file extension
147
  file_ext = os.path.splitext(file)[1].lower()
148
+
149
  if file_ext == '.sql':
150
  return process_sql_file(file)
151
  elif file_ext == '.csv':
152
  return process_csv_file(file)
153
  else:
154
  return False, "Error: Unsupported file type. Please upload either a .sql or .csv file."
155
+
156
  except Exception as e:
157
  return False, f"Error processing file: {str(e)}"
158
 
 
160
  def sql_engine(query: str) -> str:
161
  """
162
  Executes an SQL query and returns formatted results.
163
+
164
  Args:
165
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
166
+
167
  Returns:
168
  str: The formatted query results as a string.
169
  """
170
  try:
171
  with engine.connect() as con:
172
  rows = con.execute(text(query)).fetchall()
173
+
174
  if not rows:
175
  return "No results found."
176
+
177
+ if len(rows) == 1 and len(rows[0]) == 1:
178
  return str(rows[0][0])
179
+
180
  return "\n".join([", ".join(map(str, row)) for row in rows])
181
+
182
  except Exception as e:
183
  return f"Error: {str(e)}"
184
 
 
187
  model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
188
  )
189
 
190
+ def query_sql(user_query: str) -> str:
191
  """
192
  Converts natural language input to an SQL query using CodeAgent.
193
  """
194
  table_name, column_names, column_info = get_table_info()
195
 
196
  if not table_name:
197
+ return "Error: No data table exists. Please upload a file first."
198
+
199
  schema_info = (
200
  f"The database has a table named '{table_name}' with the following columns:\n"
201
  + "\n".join([
 
209
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
210
  )
211
 
212
+ # Get SQL from the agent
213
+ generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
214
+
215
+ if not isinstance(generated_sql, str):
216
+ return "Error: Invalid query generated"
217
 
218
  # Clean up the SQL
 
219
  if generated_sql.isnumeric(): # If the agent returned just a number
220
+ return generated_sql
221
 
222
  # Extract just the SQL query if there's additional text
223
  sql_lines = [line for line in generated_sql.split('\n') if 'select' in line.lower()]
 
231
  for wrong_name in ['table_name', 'customers', 'main']:
232
  if wrong_name in generated_sql:
233
  generated_sql = generated_sql.replace(wrong_name, table_name)
234
+
235
  # Add quotes around column names that need them
236
  for col in column_names:
237
  if ' ' in col: # If column name contains spaces
 
241
  try:
242
  # Execute the query
243
  result = sql_engine(generated_sql)
244
+
245
  # Try to format as number if possible
246
  try:
247
  float_result = float(result)
248
+ return f"{float_result:,.0f}" # Format with commas, no decimals
249
  except ValueError:
250
+ return result
251
+
252
  except Exception as e:
253
  if str(e).startswith("(sqlite3.OperationalError) near"):
254
  # If it's a SQL syntax error, return the raw result
255
+ return generated_sql
256
+ return f"Error executing query: {str(e)}"
257
 
258
  # Create the Gradio interface
259
  with gr.Blocks() as demo:
260
  with gr.Group() as upload_group:
261
  gr.Markdown("""
262
  # CSVAgent
263
+
264
  Upload your data file to begin.
265
+
266
  ### Supported File Types:
267
  - CSV (.csv): CSV file with headers that will be automatically converted to a table
268
+
269
  ### CSV Requirements:
270
  - Must include headers
271
  - First column will be used as the primary key
272
  - Column types will be automatically detected
273
  - Sample CSV Files: https://github.com/datablist/sample-csv-files
274
  ### Based on ZennyKenny's SqlAgent
275
+
276
  ### SQL to CSV File Conversion
277
  https://tableconvert.com/sql-to-csv
278
  - Will work on the handling of SQL files soon.
279
+
280
+
281
  ### Try it out! Upload a CSV file and then ask a question about the data!
282
+ - There is issues with the UI displaying the answer correctly, some questions such as "How many Customers are located in Korea?"
283
  The right answer will appear in the logs, but throws an error on the "Results" section.
284
  """)
285
+
286
  file_input = gr.File(
287
  label="Upload Data File",
288
  file_types=[".csv", ".sql"],
 
295
  with gr.Column(scale=1):
296
  user_input = gr.Textbox(label="Ask a question about the data")
297
  query_output = gr.Textbox(label="Result")
298
+
 
299
  with gr.Column(scale=2):
300
  gr.Markdown("### Current Data")
301
  data_table = gr.Dataframe(
 
303
  label="Data Table",
304
  interactive=False
305
  )
306
+
307
  schema_display = gr.Markdown(value="Loading schema...")
308
  refresh_btn = gr.Button("Refresh Data")
309
 
 
316
  gr.update(visible=True),
317
  gr.update(visible=False)
318
  )
319
+
320
  success, message = process_uploaded_file(file_obj)
321
  if success:
322
  df = get_data_table()
323
+ _, _, column_info = get_table_info()
324
  schema = "\n".join([
325
  f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
326
  for col, info in column_info.items()
 
342
 
343
  def refresh_data():
344
  df = get_data_table()
345
+ _, _, column_info = get_table_info()
346
  schema = "\n".join([
347
  f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
348
  for col, info in column_info.items()
 
364
 
365
  user_input.change(
366
  fn=query_sql,
367
+ inputs=user_input,
368
+ outputs=query_output
 
 
 
 
 
 
369
  )
370
 
371
  refresh_btn.click(