hertogateis committed on
Commit 0a39414 · verified · 1 Parent(s): 879d20e

Update app.py

Files changed (1)
  1. app.py +55 -141
app.py CHANGED
@@ -1,143 +1,57 @@
- import pandas as pd
  import streamlit as st
- from transformers import TapasForQuestionAnswering, TapasTokenizer, T5ForConditionalGeneration, T5Tokenizer
- import torch
-
- # Assuming df is uploaded or pre-defined (you can replace with actual data loading logic)
- # Example DataFrame (replace with your actual file or data)
- data = {
-     'Column1': [1, 2, 3, 4],
-     'Column2': [5.5, 6.5, 7.5, 8.5],
-     'Column3': ['a', 'b', 'c', 'd']
- }
- df = pd.DataFrame(data)
-
- # Check if DataFrame is valid
- if df is not None and not df.empty:
-     # Select numeric columns
-     df_numeric = df.select_dtypes(include='number')
- else:
-     df_numeric = pd.DataFrame()  # Empty DataFrame if input is invalid
-
- # Load TAPAS model and tokenizer
- tqa_model = TapasForQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
- tqa_tokenizer = TapasTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
-
- # Load T5 model and tokenizer for rephrasing
- t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
- t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
-
- # User input for the question
- question = st.text_input('Type your question')
-
- # Process the answer using TAPAS and T5
- with st.spinner():
-     if st.button('Answer'):
-         try:
-             # Get the raw answer from TAPAS
-             inputs = tqa_tokenizer(table=df, query=question, return_tensors="pt")
-             with torch.no_grad():
-                 outputs = tqa_model(**inputs)
-             raw_answer = tqa_tokenizer.decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True)
-
-             st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>", unsafe_allow_html=True)
-             st.success(raw_answer)
-
-             # Extract relevant information from the TAPAS result
-             answer = raw_answer
-             aggregator = "average"  # Example aggregator, adjust based on raw_answer if needed
-             coordinates = []  # Example, adjust based on raw_answer
-             cells = []  # Example, adjust based on raw_answer
-
-             # Construct a base sentence replacing 'SUM' with the query term
-             base_sentence = f"The {question.lower()} of the selected data is {answer}."
-             if coordinates and cells:
-                 rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
-                              for coordinate, cell in zip(coordinates, cells)]
-                 rows_description = " and ".join(rows_info)
-                 base_sentence += f" This includes the following data: {rows_description}."
-
-             # Generate a fluent response using the T5 model, rephrasing the base sentence
-             input_text = f"Given the question: '{question}', generate a more human-readable response: {base_sentence}"
-
-             inputs = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
-             summary_ids = t5_model.generate(inputs, max_length=150, num_beams=4, early_stopping=True)
-
-             generated_text = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-             # Display the final generated response
-             st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Final Generated Response with LLM: </p>", unsafe_allow_html=True)
-             st.success(generated_text)
-
-         except Exception as e:
-             st.warning("Please retype your question and make sure to use the column name and cell value correctly.")
-
-
- # Assuming 'column_name' exists and is selected or provided by the user
- # Example of getting 'column_name' from user input (adjust this part according to your app):
- column_name = st.selectbox("Select a column", df.columns)
-
- # Manually fix the aggregator if it returns an incorrect one
- if 'MEDIAN' in question.upper() and 'AVERAGE' in aggregator.upper():
-     aggregator = 'MEDIAN'
- elif 'MIN' in question.upper() and 'AVERAGE' in aggregator.upper():
-     aggregator = 'MIN'
- elif 'MAX' in question.upper() and 'AVERAGE' in aggregator.upper():
-     aggregator = 'MAX'
- elif 'TOTAL' in question.upper() and 'SUM' in aggregator.upper():
-     aggregator = 'SUM'
-
- # Use the corrected aggregator for further processing
- summary_type = aggregator.lower()
-
- # Check if `column_name` is valid before proceeding
- if column_name and column_name in df_numeric.columns:
-     # Now, calculate the correct value using pandas based on the corrected aggregator
-     if summary_type == 'sum':
-         numeric_value = df_numeric[column_name].sum()
-     elif summary_type == 'max':
-         numeric_value = df_numeric[column_name].max()
-     elif summary_type == 'min':
-         numeric_value = df_numeric[column_name].min()
-     elif summary_type == 'average':
-         numeric_value = df_numeric[column_name].mean()
-     elif summary_type == 'count':
-         numeric_value = df_numeric[column_name].count()
-     elif summary_type == 'median':
-         numeric_value = df_numeric[column_name].median()
-     elif summary_type == 'std_dev':
-         numeric_value = df_numeric[column_name].std()
      else:
-         numeric_value = answer  # Fallback if something went wrong
- else:
-     numeric_value = "Invalid column"
-
- # Construct a natural language response
- if summary_type == 'sum':
-     natural_language_answer = f"The total {column_name} is {numeric_value}."
- elif summary_type == 'maximum':
-     natural_language_answer = f"The highest {column_name} is {numeric_value}."
- elif summary_type == 'minimum':
-     natural_language_answer = f"The lowest {column_name} is {numeric_value}."
- elif summary_type == 'average':
-     natural_language_answer = f"The average {column_name} is {numeric_value}."
- elif summary_type == 'count':
-     natural_language_answer = f"The number of entries in {column_name} is {numeric_value}."
- elif summary_type == 'median':
-     natural_language_answer = f"The median {column_name} is {numeric_value}."
- elif summary_type == 'std_dev':
-     natural_language_answer = f"The standard deviation of {column_name} is {numeric_value}."
- else:
-     natural_language_answer = f"The value for {column_name} is {numeric_value}."
-
- # Display the result to the user
- st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Analysis Results: </p>", unsafe_allow_html=True)
- st.success(f"""
-     Answer: {natural_language_answer}
-
-     Data Location:
-     Column: {column_name}
-
-     Additional Context:
-     Query Asked: "{question}"
- """)
  import streamlit as st
+ import pandas as pd
+ import openpyxl
+ from io import BytesIO
+ from fetaqa import question_answering  # Hypothetical module for FeTaQA logic
+
+ # Cache the DataFrame for performance
+ @st.cache(allow_output_mutation=True)
+ def load_data(uploaded_file):
+     if uploaded_file.name.endswith('.csv'):
+         df = pd.read_csv(uploaded_file)
+     elif uploaded_file.name.endswith(('.xlsx', '.xls')):
+         df = pd.read_excel(uploaded_file, engine='openpyxl')
      else:
+         st.error("Unsupported file format. Please upload a CSV or XLSX file.")
+         return None
+     return df
+
+ def main():
+     st.title("FeTaQA Table Question Answering")
+
+     # File uploader
+     uploaded_file = st.file_uploader("Choose a CSV or Excel file", type=["csv", "xlsx", "xls"])
+
+     if uploaded_file is not None:
+         df = load_data(uploaded_file)
+
+         if df is not None:
+             st.write("Uploaded Table:")
+             st.dataframe(df)
+
+             # Question input
+             question = st.text_input("Ask a question about the table:")
+
+             # Question history
+             if 'question_history' not in st.session_state:
+                 st.session_state['question_history'] = []
+
+             if st.button('Ask'):
+                 if question:
+                     answer = question_answering(df, question)
+                     st.write(f"Answer: {answer}")
+                     st.session_state['question_history'].append((question, answer))
+
+             # Displaying history
+             st.write("Question History:")
+             for q, a in st.session_state['question_history'][-5:]:  # Show last 5 questions
+                 st.write(f"**Q:** {q}")
+                 st.write(f"**A:** {a}")
+                 st.write("---")
+
+             # Reset history
+             if st.button('Clear History'):
+                 st.session_state['question_history'] = []
+
+ if __name__ == "__main__":
+     main()
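
The fetaqa module imported above is not part of this commit; its own inline comment marks it as hypothetical. A minimal sketch of what question_answering(df, question) could look like is shown below. It linearizes the uploaded table into text and prompts a generic seq2seq checkpoint (google/flan-t5-base) as a stand-in for a model fine-tuned on FeTaQA; the checkpoint name, prompt format, and the _linearize helper are assumptions, not anything defined by the commit.

# fetaqa.py -- hypothetical sketch of the module app.py imports; not part of this commit.
# Flattens the table into text and queries a generic seq2seq model as a stand-in
# for a FeTaQA-style free-form table-QA model.
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "google/flan-t5-base"  # assumed stand-in checkpoint, not specified by the commit
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


def _linearize(df: pd.DataFrame, max_rows: int = 50) -> str:
    """Turn a DataFrame into a 'col: value | col: value' string the model can read."""
    header = " | ".join(str(c) for c in df.columns)
    rows = [
        " | ".join(f"{col}: {val}" for col, val in row.items())
        for _, row in df.head(max_rows).iterrows()
    ]
    return "columns: " + header + "\n" + "\n".join(rows)


def question_answering(df: pd.DataFrame, question: str) -> str:
    """Answer a free-form question about df with a generated sentence."""
    prompt = (
        "Answer the question using the table below.\n"
        f"Question: {question}\n"
        f"Table:\n{_linearize(df)}"
    )
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    output_ids = _model.generate(**inputs, max_length=96, num_beams=4)
    return _tokenizer.decode(output_ids[0], skip_special_tokens=True)

With a file along these lines saved as fetaqa.py next to app.py, the app would start with `streamlit run app.py` and call question_answering(df, question) exactly as the new code does.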