|
import datetime
import functools
import time

import gradio as gr
import pandas as pd
import torch
from transformers import BartForConditionalGeneration, TapexTokenizer
|
|
|
MODEL_NAME = "microsoft/tapex-large-finetuned-wtq"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name=MODEL_NAME):
    """Load the TAPEX model and tokenizer once and reuse them for every query.

    ``from_pretrained`` loads the full checkpoint from disk (or the Hub), so
    calling it inside the request handler made every single query pay that
    cost. The cache keeps exactly one (model, tokenizer) pair alive.
    """
    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = TapexTokenizer.from_pretrained(model_name)
    return model, tokenizer


def execute_query(query, csv_file):
    """Answer a natural-language question about an uploaded CSV with TAPEX.

    Args:
        query: The question typed into the "Search query" textbox.
        csv_file: The value of the ``gr.File`` input. Depending on the Gradio
            version this is a tempfile-like object exposing ``.name`` or a
            plain path string; it is ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(query_result, table)``. ``query_result`` is a dict with
        ``"query"`` (the question) and ``"answer"`` (the decoded model
        output). ``table`` is the CSV contents as a DataFrame with every cell
        converted to ``str``.

    Raises:
        gr.Error: If no CSV file was uploaded.
    """
    start = time.perf_counter()

    if csv_file is None:
        raise gr.Error("Please upload a CSV file.")
    # Older Gradio passes a tempfile wrapper, newer versions pass the path itself.
    csv_path = getattr(csv_file, "name", csv_file)

    # Every cell is converted to str before the table is handed to the tokenizer.
    table = pd.read_csv(csv_path, delimiter=",").astype(str)

    model, tokenizer = _load_model_and_tokenizer()

    # The tokenizer takes a batch of queries; this handler answers exactly one.
    encoding = tokenizer(
        table=table,
        query=[query],
        padding=True,
        return_tensors="pt",
        truncation=True,
    )
    outputs = model.generate(**encoding)
    answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    query_result = {
        "query": query,
        "answer": answers[0],
    }

    # Same "H:MM:SS.ffffff" log line as before, but measured on a monotonic clock
    # so wall-clock adjustments cannot skew (or negate) the elapsed time.
    print(datetime.timedelta(seconds=time.perf_counter() - start))

    return query_result, table
|
|
|
def main():
    """Build the TAPEX table-question-answering Gradio interface and serve it.

    The interface takes a question and a CSV upload, calls ``execute_query``,
    and shows the answer as JSON alongside the full table.
    """
    # NOTE(review): the description promises a maximum of 5000 rows / 20 columns
    # and a sample Shopify dataset, but no limit is enforced and no `examples=`
    # are passed to the Interface — confirm whether these should be added.
    description = "Querying a CSV using the TAPEX model. You can ask a question about tabular data, and the TAPEX model will produce the result. The finetuned TAPEX model runs on data with a maximum of 5000 rows and 20 columns. A sample dataset of Shopify store sales is provided."

    article = "<p style='text-align: center'><a href='https://unscrambl.com/' target='_blank'>Unscrambl</a> | <a href='https://huggingface.co/microsoft/tapex-large-finetuned-wtq' target='_blank'>TAPEX Model</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_tablequery' alt='visitor badge'></center>"

    iface = gr.Interface(
        fn=execute_query,
        inputs=[
            gr.Textbox(label="Search query"),
            gr.File(label="CSV file"),
        ],
        outputs=[
            gr.JSON(label="Result"),
            gr.Dataframe(label="All data"),
        ],
        title="Table Question Answering (TAPEX)",
        description=description,
        article=article,
        allow_flagging='never',
    )

    # launch(enable_queue=True) is deprecated in Gradio 3.x and removed in 4.0;
    # queue() enables request queuing on both and returns the interface itself.
    iface.queue().launch()
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when run as a script, not when imported.
    main()
|
|