hertogateis commited on
Commit
bc8fae9
·
verified ·
1 Parent(s): d622429

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -52
app.py CHANGED
@@ -1,57 +1,92 @@
 
1
  import streamlit as st
 
2
  import pandas as pd
3
- import openpyxl
4
- from io import BytesIO
5
- from fetaqa import question_answering # Hypothetical module for FeTaQA logic
6
-
7
- # Cache the DataFrame for performance
8
- @st.cache(allow_output_mutation=True)
9
- def load_data(uploaded_file):
10
- if uploaded_file.name.endswith('.csv'):
11
- df = pd.read_csv(uploaded_file)
12
- elif uploaded_file.name.endswith(('.xlsx', '.xls')):
13
- df = pd.read_excel(uploaded_file, engine='openpyxl')
14
- else:
15
- st.error("Unsupported file format. Please upload a CSV or XLSX file.")
16
- return None
17
- return df
18
-
19
- def main():
20
- st.title("FeTaQA Table Question Answering")
21
-
22
- # File uploader
23
- uploaded_file = st.file_uploader("Choose a CSV or Excel file", type=["csv", "xlsx", "xls"])
24
-
25
- if uploaded_file is not None:
26
- df = load_data(uploaded_file)
27
-
 
 
 
 
 
 
 
 
 
28
  if df is not None:
29
- st.write("Uploaded Table:")
30
- st.dataframe(df)
31
-
32
- # Question input
33
- question = st.text_input("Ask a question about the table:")
34
-
35
- # Question history
36
- if 'question_history' not in st.session_state:
37
- st.session_state['question_history'] = []
38
-
39
- if st.button('Ask'):
40
- if question:
41
- answer = question_answering(df, question)
42
- st.write(f"Answer: {answer}")
43
- st.session_state['question_history'].append((question, answer))
44
-
45
- # Displaying history
46
- st.write("Question History:")
47
- for q, a in st.session_state['question_history'][-5:]: # Show last 5 questions
48
- st.write(f"**Q:** {q}")
49
- st.write(f"**A:** {a}")
50
- st.write("---")
51
 
52
- # Reset history
53
- if st.button('Clear History'):
54
- st.session_state['question_history'] = []
55
 
56
- if __name__ == "__main__":
57
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import streamlit as st
3
+ from st_aggrid import AgGrid
4
  import pandas as pd
5
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
6
+
7
+ # Set the page layout for Streamlit
8
+ st.set_page_config(layout="wide")
9
+
10
+ # CSS styling
11
+ # ... (keep your existing CSS code)
12
+
13
+ # Initialize TAPAS pipeline
14
+ tqa = pipeline(task="table-question-answering",
15
+ model="google/tapas-large-finetuned-wtq",
16
+ device="cpu")
17
+
18
+ # Initialize T5 tokenizer and model for text generation
19
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
20
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
21
+
22
+ # File uploader in the sidebar
23
+ file_name = st.sidebar.file_uploader("Upload file:", type=['csv', 'xlsx'])
24
+
25
+ # File processing and question answering
26
+ if file_name is None:
27
+ st.markdown('<p class="font">Please upload an excel or csv file </p>', unsafe_allow_html=True)
28
+ else:
29
+ try:
30
+ # Check file type and handle reading accordingly
31
+ if file_name.name.endswith('.csv'):
32
+ df = pd.read_csv(file_name, sep=';', encoding='ISO-8859-1') # Adjust encoding if needed
33
+ elif file_name.name.endswith('.xlsx'):
34
+ df = pd.read_excel(file_name, engine='openpyxl') # Use openpyxl to read .xlsx files
35
+ else:
36
+ st.error("Unsupported file type")
37
+ df = None
38
+
39
  if df is not None:
40
+ numeric_columns = df.select_dtypes(include=['object']).columns
41
+ for col in numeric_columns:
42
+ df[col] = pd.to_numeric(df[col], errors='ignore')
43
+
44
+ st.write("Original Data:")
45
+ st.write(df)
46
+
47
+ df_numeric = df.copy()
48
+ df = df.astype(str)
49
+
50
+ # Display the first 5 rows of the dataframe in an editable grid
51
+ grid_response = AgGrid(
52
+ df.head(5),
53
+ columns_auto_size_mode='FIT_CONTENTS',
54
+ editable=True,
55
+ height=300,
56
+ width='100%',
57
+ )
 
 
 
 
58
 
59
+ except Exception as e:
60
+ st.error(f"Error reading file: {str(e)}")
 
61
 
62
+ # User input for the question
63
+ question = st.text_input('Type your question')
64
+
65
+ # Process the answer using TAPAS and T5
66
+ with st.spinner():
67
+ if st.button('Answer'):
68
+ try:
69
+ raw_answer = tqa(table=df, query=question, truncation=True)
70
+
71
+ st.markdown("<p style='font-family:sans-serif;font-size: 0.9rem;'> Raw Result From TAPAS: </p>",
72
+ unsafe_allow_html=True)
73
+ st.success(raw_answer)
74
+
75
+ answer = raw_answer['answer']
76
+ aggregator = raw_answer.get('aggregator', '')
77
+ coordinates = raw_answer.get('coordinates', [])
78
+ cells = raw_answer.get('cells', [])
79
+
80
+ if aggregator == 'SUM':
81
+ # Convert cell values to numbers and sum them
82
+ values = [float(cell) for cell in cells if cell.replace('.', '').isdigit()]
83
+ total_sum = sum(values)
84
+ base_sentence = f"The sum for '{question}' is {total_sum}."
85
+ else:
86
+ # Construct a base sentence for other aggregators or no aggregation
87
+ base_sentence = f"The answer from TAPAS for '{question}' is {answer}."
88
+ if coordinates and cells:
89
+ rows_info = [f"Row {coordinate[0] + 1}, Column '{df.columns[coordinate[1]]}' with value {cell}"
90
+ for coordinate, cell in zip(coordinates, cells)]
91
+ rows_description = " and ".join(rows_info)
92
+ base_sentence += f" This includes the following data: