ZennyKenny's picture
Update app.py
1ea874c verified
raw
history blame
3.59 kB
import gradio as gr
import pandas as pd
from transformers import pipeline
# Hugging Face checkpoints backing the two model pipelines.
SENTIMENT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
CATEGORY_MODEL = "google/flan-t5-base"

# Sentiment labels for each comment come from this classification pipeline.
classifier = pipeline(task="text-classification", model=SENTIMENT_MODEL)
# Categories are picked by prompting this text-to-text generation pipeline.
generator = pipeline(task="text2text-generation", model=CATEGORY_MODEL)

# Comment data starts out empty; load_data replaces it with the uploaded CSV.
df = pd.DataFrame()
# Function to classify customer comments
def classify_comments(categories):
    """Label every loaded comment with a sentiment and one of ``categories``.

    Sentiment comes from the module-level ``classifier`` pipeline. The
    category is whatever the ``generator`` pipeline answers when prompted
    with the comment and the allowed category names. Both results are
    stored on the global ``df`` as the ``comment_sentiment`` and
    ``comment_category`` columns.

    Args:
        categories: Category names collected in the UI.

    Returns:
        An HTML table of ``customer_comment``, ``comment_sentiment`` and
        ``comment_category``, or a plain status message when there is no
        usable data or no categories are defined.
    """
    if df.empty:
        return "No data loaded."
    if not categories:
        return "No categories defined."
    if 'customer_comment' not in df.columns:
        # Guard against a loaded frame that lacks the expected column rather
        # than failing with a KeyError inside the event handler.
        return "Error: Uploaded CSV must contain a column named 'customer_comment'"

    # The allowed categories are identical for every comment, so build the
    # list text once instead of on every iteration.
    category_str = ', '.join(categories)
    # Blank CSV cells are read as NaN (a float); the pipelines need strings.
    comments = df['customer_comment'].fillna("").astype(str)

    sentiments = []
    comment_categories = []
    for comment in comments:
        # truncation=True keeps over-long comments within the classifier's
        # maximum input length instead of raising.
        sentiment = classifier(comment, truncation=True)[0]['label']
        prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
        category = generator(prompt, max_length=30)[0]['generated_text'].strip()
        sentiments.append(sentiment)
        comment_categories.append(category)

    df['comment_sentiment'] = sentiments
    df['comment_category'] = comment_categories
    return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)
# Function to load data from uploaded file
def load_data(file):
    """Load an uploaded CSV into the module-level ``df``.

    The global ``df`` is replaced only when the CSV parses and contains a
    ``customer_comment`` column. On any failure the previously loaded data
    is kept.

    Args:
        file: The value supplied by the upload component. This is an object
            whose ``.name`` attribute is the file path, or, as a fallback, a
            path string or file-like buffer accepted by ``pd.read_csv``.
            ``None`` means nothing was uploaded.

    Returns:
        A status message for the upload panel.
    """
    global df
    if file is None:
        return "No file uploaded."
    # Upload components hand over a wrapper whose ``.name`` is the path; fall
    # back to the value itself when it is already a path or a buffer.
    source = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(source)
    except Exception as e:  # UI boundary: report any parse/IO error to the user
        return f"Error loading CSV: {e}"
    if 'customer_comment' not in loaded.columns:
        # Do not commit a frame that classify_comments cannot use.
        return "Error: Uploaded CSV must contain a column named 'customer_comment'"
    df = loaded
    return "Custom CSV loaded successfully!"
# Function to add a new category
def add_category(categories, new_category):
    """Add ``new_category`` to the category list if it is non-blank and new.

    Args:
        categories: Current category names held in the ``gr.State``.
        new_category: Raw text typed into the "New Category" textbox.

    Returns:
        A ``(categories, textbox_update)`` pair. The first item is a fresh
        list, so the incoming state list is not mutated. The second is a
        ``gr.update`` that clears the textbox.
    """
    updated = list(categories or [])
    # Strip before the duplicate check so " Billing" cannot be added next to
    # an existing "Billing".
    cleaned = (new_category or "").strip()
    if cleaned and cleaned not in updated:
        updated.append(cleaned)
    return updated, gr.update(value="")  # Clear the input box after adding
# Function to remove a category
def remove_category(categories, category_to_remove):
    """Return a new list holding every category except ``category_to_remove``.

    All matching entries are dropped, and the input list itself is left
    unchanged.
    """
    return list(filter(lambda name: name != category_to_remove, categories))
# Function to display categories with remove buttons
def display_categories(categories):
    """Build one row per category, each with a Markdown label and a "Remove" button.

    Args:
        categories: List of category names (the value of the ``categories``
            State, as wired in the UI section).

    Returns:
        The list of ``gr.Row`` containers created, one per category.

    NOTE(review): this is wired as a ``categories.change`` event handler, so
    it runs outside the ``with gr.Blocks()`` context. Gradio generally only
    renders components and accepts listeners such as ``.click`` inside a
    Blocks context. The rows are therefore unlikely to appear on the page,
    and ``.click`` may raise. Confirm against the installed Gradio version.
    """
    category_elements = []
    for category in categories:
        with gr.Row() as category_row:
            gr.Markdown(f"- {category}")
            remove_button = gr.Button("Remove")
            # NOTE(review): the gr.State(...) objects here are new components
            # that are not the page's ``categories`` state. The list returned by
            # remove_category is probably never written back to it. Verify.
            remove_button.click(fn=remove_category, inputs=[gr.State(categories), gr.State(category)], outputs=gr.State(categories), queue=False)
        category_elements.append(category_row)
    return category_elements
# Gradio Interface
def _refresh_categories(categories):
    """Return the Markdown listing and a refreshed remove-dropdown for ``categories``.

    Components created inside an event handler are not added to the page.
    Instead, the category list is rendered into one pre-built Markdown
    component, and the dropdown's choices are updated to match.
    """
    current = list(categories or [])
    if current:
        listing = "\n".join(f"- {category}" for category in current)
    else:
        listing = "_No categories defined yet._"
    return listing, gr.update(choices=current, value=None)


with gr.Blocks() as nps:
    gr.Markdown("# NPS Comment Categorization")
    with gr.Row():
        category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
        add_category_btn = gr.Button("Add Category")
    categories = gr.State([])  # Initialize an empty list for categories
    category_list = gr.Markdown("_No categories defined yet._")
    with gr.Row():
        category_to_remove = gr.Dropdown(label="Category to Remove", choices=[])
        remove_category_btn = gr.Button("Remove Category")

    # Chain the refresh with .then() so the listing always follows the state
    # update, without relying on a change event on gr.State.
    add_category_btn.click(
        fn=add_category,
        inputs=[categories, category_input],
        outputs=[categories, category_input],
        queue=False,
    ).then(
        fn=_refresh_categories,
        inputs=categories,
        outputs=[category_list, category_to_remove],
        queue=False,
    )
    remove_category_btn.click(
        fn=remove_category,
        inputs=[categories, category_to_remove],
        outputs=categories,
        queue=False,
    ).then(
        fn=_refresh_categories,
        inputs=categories,
        outputs=[category_list, category_to_remove],
        queue=False,
    )

    # In Gradio 4+ gr.File accepts only "filepath" or "binary". The filepath
    # value is a str subclass that still exposes ``.name``, which load_data reads.
    uploaded_file = gr.File(label="Upload CSV", type="filepath")
    upload_output = gr.HTML()
    uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=upload_output, queue=False)
    classify_btn = gr.Button("Classify Comments")
    output = gr.HTML()
    classify_btn.click(fn=classify_comments, inputs=categories, outputs=output, queue=False)

if __name__ == "__main__":
    nps.launch()