# Spaces: Running on Zero
import os

import gradio as gr
import pandas as pd
from transformers import pipeline
# Model pipelines are constructed once, when the module is loaded.
# Sentiment labels come from a text-classification checkpoint.
classifier = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
# Category names are produced by a text2text-generation checkpoint.
generator = pipeline(
    task="text2text-generation",
    model="google/flan-t5-base",
)

# Module-level table of comments; empty until load_data() replaces it.
df = pd.DataFrame()
def classify_comments(categories):
    """Label every loaded comment with a sentiment and a category.

    Reads the module-level ``df``, passes each ``customer_comment`` value to
    ``classifier`` for a sentiment label and to ``generator`` for a category
    chosen from ``categories``, then stores the results in the
    ``comment_sentiment`` and ``comment_category`` columns of ``df``.

    Args:
        categories: Category names offered to the generator in the prompt.

    Returns:
        An HTML table (without the index column) of the comments and their
        labels, or a plain message when no data is loaded, the
        ``customer_comment`` column is missing, or no categories are defined.
    """
    if df.empty:
        return "No data loaded."
    # A frame without the expected column would raise KeyError below; report
    # the problem as a message like the other early exits do.
    if 'customer_comment' not in df.columns:
        return "Error: Uploaded CSV must contain a column named 'customer_comment'"
    if not categories:
        return "No categories defined."

    # The category list is the same for every comment, so join it once.
    category_str = ', '.join(categories)

    sentiments = []
    comment_categories = []
    for raw_comment in df['customer_comment']:
        # pandas reads empty CSV cells as NaN rather than text; leave such
        # rows (and whitespace-only ones) unlabeled instead of handing a
        # non-string or empty value to the models.
        if pd.isna(raw_comment) or not str(raw_comment).strip():
            sentiments.append("")
            comment_categories.append("")
            continue

        comment = str(raw_comment)
        sentiment = classifier(comment)[0]['label']
        prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
        category = generator(prompt, max_length=30)[0]['generated_text']
        sentiments.append(sentiment)
        comment_categories.append(category)

    df['comment_sentiment'] = sentiments
    df['comment_category'] = comment_categories
    return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)
def load_data(file):
    """Load an uploaded CSV of customer comments into the module-level ``df``.

    Args:
        file: The upload to read. Either a filesystem path (``str`` or
            ``os.PathLike``) or an object whose ``name`` attribute holds the
            path. ``None`` means nothing was uploaded.

    Returns:
        A status message: success, "No file uploaded.", a missing-column
        error, or "Error loading CSV: ..." with the underlying exception text.
    """
    global df
    if file is None:
        return "No file uploaded."

    # Plain paths are used directly; anything else is treated as a file
    # wrapper exposing its path as ``.name``. Path objects are checked first
    # because their own ``.name`` is only the final path component.
    if isinstance(file, (str, os.PathLike)):
        csv_path = file
    else:
        csv_path = file.name

    try:
        loaded = pd.read_csv(csv_path)
    except Exception as e:  # UI boundary: surface any read/parse failure as text
        return f"Error loading CSV: {e}"

    # Validate before assigning the global so that a bad upload neither
    # discards previously loaded data nor leaves an unusable frame behind.
    if 'customer_comment' not in loaded.columns:
        return "Error: Uploaded CSV must contain a column named 'customer_comment'"

    df = loaded
    return "Custom CSV loaded successfully!"
def add_category(categories, new_category):
    """Add a category name to the list unless it is blank or already present.

    Args:
        categories: Current list of category names (``None`` is treated as
            an empty list).
        new_category: Text entered by the user; surrounding whitespace is
            ignored and ``None`` is treated as blank.

    Returns:
        A tuple ``(updated_categories, textbox_update)`` where
        ``updated_categories`` is a new list and ``textbox_update`` is a
        ``gr.update`` that clears the input box.
    """
    # Work on a copy so the caller's list is never mutated in place.
    updated = list(categories or [])
    name = (new_category or "").strip()
    # Compare the stripped name: it is the stripped name that gets stored, so
    # " price " must not be added as a duplicate of "price".
    if name and name not in updated:
        updated.append(name)
    return updated, gr.update(value="")  # Clear the input box after adding
def remove_category(categories, category_to_remove):
    """Return a new list without any entry equal to ``category_to_remove``.

    Args:
        categories: Iterable of category names; it is not modified.
        category_to_remove: The name to drop (all occurrences).

    Returns:
        A list of the remaining names in their original order.
    """
    remaining = []
    for name in categories:
        if name != category_to_remove:
            remaining.append(name)
    return remaining
def _build_category_row(categories, category):
    """Create one row holding a category label and its "Remove" button."""
    with gr.Row() as row:
        gr.Markdown(f"- {category}")
        remove_btn = gr.Button("Remove")
        remove_btn.click(
            fn=remove_category,
            inputs=[gr.State(categories), gr.State(category)],
            outputs=gr.State(categories),
            queue=False,
        )
    return row


def display_categories(categories):
    """Build one row (label plus remove button) per category.

    Args:
        categories: Category names to display.

    Returns:
        A list of the created ``gr.Row`` objects, in the order of
        ``categories``.
    """
    # NOTE(review): the rows and the gr.State wrappers are created inside an
    # event handler; whether Gradio renders them or links these states to the
    # page's `categories` state should be confirmed against the Gradio version
    # in use.
    return [_build_category_row(categories, name) for name in categories]
# Gradio Interface
with gr.Blocks() as nps:
    # ---- Layout (creation order is the on-page order) ----
    gr.Markdown("# NPS Comment Categorization")
    with gr.Row():
        category_input = gr.Textbox(
            label="New Category",
            placeholder="Enter category name",
        )
        add_category_btn = gr.Button("Add Category")
    category_column = gr.Column()
    # Category names defined so far; starts as an empty list.
    categories = gr.State([])
    uploaded_file = gr.File(label="Upload CSV", type="file")
    upload_output = gr.HTML()
    classify_btn = gr.Button("Classify Comments")
    output = gr.HTML()

    # ---- Event wiring ----
    # Adding a category updates the state and clears the textbox.
    add_category_btn.click(
        fn=add_category,
        inputs=[categories, category_input],
        outputs=[categories, category_input],
        queue=False,
    )
    # Any change to the category state refreshes the category column.
    categories.change(
        fn=display_categories,
        inputs=categories,
        outputs=category_column,
        queue=False,
    )
    # A new upload is loaded and its status message shown.
    uploaded_file.change(
        fn=load_data,
        inputs=uploaded_file,
        outputs=upload_output,
        queue=False,
    )
    # Classification runs on demand with the current categories.
    classify_btn.click(
        fn=classify_comments,
        inputs=categories,
        outputs=output,
        queue=False,
    )

nps.launch()