Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from transformers import pipeline | |
import pandas as pd | |
import spaces | |
# Load dataset | |
from datasets import load_dataset | |
ds = load_dataset('ZennyKenny/demo_customer_nps') | |
df = pd.DataFrame(ds['train']) | |
# Initialize model pipeline | |
from huggingface_hub import login | |
import os | |
# Login using the API key stored as an environment variable | |
hf_api_key = os.getenv("API_KEY") | |
login(token=hf_api_key) | |
classifier = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english") | |
generator = pipeline("text2text-generation", model="google/flan-t5-base") | |
# Function to classify customer comments | |
def classify_comments(category_boxes): | |
sentiments = [] | |
categories = [] | |
for comment in df['customer_comment']: | |
sentiment = classifier(comment)[0]['label'] | |
category_list = [box for box in category_boxes if box.strip() != ''] | |
category_str = ', '.join([cat.strip() for cat in category_list]) | |
prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}." | |
category = generator(prompt, max_length=30)[0]['generated_text'] | |
categories.append(category) | |
sentiments.append(sentiment) | |
df['comment_sentiment'] = sentiments | |
df['comment_category'] = categories | |
return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False) | |
# Gradio Interface | |
with gr.Blocks() as nps: | |
def add_category(category_list, new_category): | |
if new_category.strip() != "": | |
category_list.append(new_category.strip()) # Add new category | |
return category_list | |
def remove_category(category, category_list): | |
category_list.remove(category) # Remove selected category | |
return category_list | |
def display_categories(categories): | |
category_components = [] | |
for i, cat in enumerate(categories): | |
with gr.Row(): | |
gr.Markdown(f"- {cat}") | |
remove_btn = gr.Button("X", elem_id=f"remove_{i}", interactive=True) | |
remove_btn.click( | |
fn=lambda x=cat: remove_category(x, categories), | |
inputs=[], | |
outputs=category_boxes | |
) | |
category_components.append(gr.Row()) | |
return category_components | |
category_boxes = gr.State([]) # Store category input boxes as state | |
category_column = gr.Column() | |
with gr.Row(): | |
category_input = gr.Textbox(label="New Category", placeholder="Enter category name") | |
add_category_btn = gr.Button("Add Category") | |
add_category_btn.click( | |
fn=add_category, | |
inputs=[category_boxes, category_input], | |
outputs=category_boxes | |
) | |
category_boxes.change( | |
fn=display_categories, | |
inputs=category_boxes, | |
outputs=category_column | |
) | |
uploaded_file = gr.File(label="Upload CSV", type="filepath") | |
template_btn = gr.Button("Use Template") | |
gr.Markdown("# NPS Comment Categorization") | |
classify_btn = gr.Button("Classify Comments") | |
output = gr.HTML() | |
def load_data(file): | |
if file is not None: | |
file.seek(0) # Reset file pointer | |
import io | |
if file.name.endswith('.csv'): | |
custom_df = pd.read_csv(file, encoding='utf-8') | |
else: | |
return "Error: Uploaded file is not a CSV." | |
if 'customer_comment' not in custom_df.columns: | |
return "Error: Uploaded CSV must contain a column named 'customer_comment'" | |
global df | |
df = custom_df | |
return "Custom CSV loaded successfully!" | |
else: | |
return "No file uploaded." | |
uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=output) | |
template_btn.click(fn=lambda: "Using Template Dataset", outputs=output) | |
def use_template(): | |
return ["Product Experience", "Customer Support", "Price of Service", "Other"] | |
template_btn.click(fn=use_template, outputs=category_boxes) | |
category_boxes.change(fn=display_categories, inputs=category_boxes, outputs=category_column) | |
classify_btn.click(fn=classify_comments, inputs=category_boxes, outputs=output) | |
nps.launch(share=True) |