Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,90 +1,131 @@
|
|
1 |
import gradio as gr
|
2 |
-
import pandas as pd
|
3 |
from transformers import pipeline
|
|
|
|
|
4 |
|
5 |
-
#
|
6 |
-
|
7 |
-
|
|
|
8 |
|
9 |
-
# Initialize
|
10 |
-
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
|
|
|
|
|
|
|
|
|
19 |
sentiments = []
|
20 |
-
|
|
|
21 |
for comment in df['customer_comment']:
|
22 |
sentiment = classifier(comment)[0]['label']
|
23 |
-
|
|
|
24 |
prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
|
25 |
category = generator(prompt, max_length=30)[0]['generated_text']
|
|
|
26 |
sentiments.append(sentiment)
|
27 |
-
comment_categories.append(category)
|
28 |
df['comment_sentiment'] = sentiments
|
29 |
-
df['comment_category'] =
|
30 |
return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)
|
31 |
|
32 |
-
# Function to load data from uploaded file
|
33 |
-
def load_data(file):
|
34 |
-
global df
|
35 |
-
if file is not None:
|
36 |
-
try:
|
37 |
-
df = pd.read_csv(file.name)
|
38 |
-
if 'customer_comment' not in df.columns:
|
39 |
-
return "Error: Uploaded CSV must contain a column named 'customer_comment'"
|
40 |
-
return "Custom CSV loaded successfully!"
|
41 |
-
except Exception as e:
|
42 |
-
return f"Error loading CSV: {e}"
|
43 |
-
else:
|
44 |
-
return "No file uploaded."
|
45 |
-
|
46 |
-
# Function to add a new category
|
47 |
-
def add_category(categories, new_category):
|
48 |
-
if new_category.strip() and new_category not in categories:
|
49 |
-
categories.append(new_category.strip())
|
50 |
-
return categories, gr.update(value="") # Clear the input box after adding
|
51 |
-
|
52 |
-
# Function to remove a category
|
53 |
-
def remove_category(categories, category_to_remove):
|
54 |
-
categories = [cat for cat in categories if cat != category_to_remove]
|
55 |
-
return categories
|
56 |
-
|
57 |
-
# Function to display categories with remove buttons
|
58 |
-
def display_categories(categories):
|
59 |
-
category_elements = []
|
60 |
-
for category in categories:
|
61 |
-
with gr.Row() as category_row:
|
62 |
-
gr.Markdown(f"- {category}")
|
63 |
-
remove_button = gr.Button("Remove")
|
64 |
-
remove_button.click(fn=remove_category, inputs=[gr.State(categories), gr.State(category)], outputs=gr.State(categories), queue=False)
|
65 |
-
category_elements.append(category_row)
|
66 |
-
return category_elements
|
67 |
-
|
68 |
# Gradio Interface
|
69 |
with gr.Blocks() as nps:
|
70 |
-
|
|
|
|
|
|
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
with gr.Row():
|
73 |
category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
|
74 |
add_category_btn = gr.Button("Add Category")
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
category_column = gr.Column()
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
upload_output = gr.HTML()
|
84 |
-
uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=upload_output, queue=False)
|
85 |
-
|
86 |
classify_btn = gr.Button("Classify Comments")
|
87 |
output = gr.HTML()
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
nps.launch()
|
|
|
1 |
import gradio as gr
|
|
|
2 |
from transformers import pipeline
|
3 |
+
import pandas as pd
|
4 |
+
import spaces
|
5 |
|
6 |
+
# Load dataset
|
7 |
+
from datasets import load_dataset
|
8 |
+
ds = load_dataset('ZennyKenny/demo_customer_nps')
|
9 |
+
df = pd.DataFrame(ds['train'])
|
10 |
|
11 |
+
# Initialize model pipeline
|
12 |
+
from huggingface_hub import login
|
13 |
+
import os
|
14 |
|
15 |
+
# Login using the API key stored as an environment variable
|
16 |
+
hf_api_key = os.getenv("API_KEY")
|
17 |
+
login(token=hf_api_key)
|
18 |
+
|
19 |
+
classifier = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
|
20 |
+
generator = pipeline("text2text-generation", model="google/flan-t5-base")
|
21 |
|
22 |
+
# Function to classify customer comments
|
23 |
+
# https://huggingface.co/docs/hub/spaces-zerogpu
|
24 |
+
@spaces.GPU
|
25 |
+
def classify_comments():
|
26 |
sentiments = []
|
27 |
+
categories = []
|
28 |
+
results = []
|
29 |
for comment in df['customer_comment']:
|
30 |
sentiment = classifier(comment)[0]['label']
|
31 |
+
category_list = [box.value for box in category_boxes if box.value.strip() != '']
|
32 |
+
category_str = ', '.join([cat.strip() for cat in category_list])
|
33 |
prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
|
34 |
category = generator(prompt, max_length=30)[0]['generated_text']
|
35 |
+
categories.append(category)
|
36 |
sentiments.append(sentiment)
|
|
|
37 |
df['comment_sentiment'] = sentiments
|
38 |
+
df['comment_category'] = categories
|
39 |
return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# Gradio Interface
|
42 |
with gr.Blocks() as nps:
|
43 |
+
def add_category(category_list, new_category):
|
44 |
+
if new_category.strip() != "":
|
45 |
+
category_list.append(new_category.strip()) # Add new category
|
46 |
+
return category_list
|
47 |
|
48 |
+
category_boxes = gr.State([]) # Store category input boxes as state
|
49 |
+
|
50 |
+
def display_categories(categories):
|
51 |
+
category_column.children = [] # Clear previous categories
|
52 |
+
for i, cat in enumerate(categories):
|
53 |
+
with category_column:
|
54 |
+
with gr.Row():
|
55 |
+
gr.Markdown(f"- {cat}")
|
56 |
+
remove_btn = gr.Button("X", elem_id=f"remove_{i}", interactive=True)
|
57 |
+
remove_btn.click(fn=lambda x=cat: remove_category(x, categories), inputs=[], outputs=category_boxes)
|
58 |
+
category_column.children = [] # Reset children to clear previous categories
|
59 |
+
for i, cat in enumerate(categories):
|
60 |
+
with category_column:
|
61 |
+
with gr.Row():
|
62 |
+
gr.Markdown(f"- {cat}")
|
63 |
+
remove_btn = gr.Button("X", elem_id=f"remove_{i}", interactive=True)
|
64 |
+
remove_btn.click(fn=lambda x=cat: remove_category(x, categories), inputs=[], outputs=category_boxes)
|
65 |
+
category_components = []
|
66 |
+
for i, cat in enumerate(categories):
|
67 |
+
with gr.Row():
|
68 |
+
gr.Markdown(f"- {cat}")
|
69 |
+
remove_btn = gr.Button("X", elem_id=f"remove_{i}", interactive=True)
|
70 |
+
remove_btn.click(fn=lambda x=cat: remove_category(x, categories), inputs=[], outputs=category_boxes)
|
71 |
+
return category_components
|
72 |
with gr.Row():
|
73 |
category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
|
74 |
add_category_btn = gr.Button("Add Category")
|
75 |
+
add_category_btn.click(fn=add_category, inputs=[category_boxes, category_input], outputs=category_boxes)
|
76 |
+
category_boxes.change(fn=display_categories, inputs=category_boxes, outputs=category_column)
|
77 |
+
def remove_category(category, category_list):
|
78 |
+
category_list.remove(category) # Remove selected category
|
79 |
+
return category_list
|
80 |
+
components = []
|
81 |
+
for i, cat in enumerate(categories):
|
82 |
+
row = gr.Row([
|
83 |
+
gr.Markdown(f"- {cat}"),
|
84 |
+
gr.Button("X", elem_id=f"remove_{i}", interactive=True).click(fn=lambda x=cat: remove_category(x, categories), inputs=[], outputs=category_boxes)
|
85 |
+
])
|
86 |
+
components.append(row)
|
87 |
+
return components
|
88 |
+
for i, cat in enumerate(categories):
|
89 |
+
row = gr.Row([
|
90 |
+
gr.Textbox(value=cat, label=f"Category {i+1}", interactive=True),
|
91 |
+
gr.Button("X", elem_id=f"remove_{i}")
|
92 |
+
])
|
93 |
+
components.append(row)
|
94 |
+
return components
|
95 |
category_column = gr.Column()
|
96 |
+
add_category_btn.click(fn=add_category, inputs=[category_boxes, category_input], outputs=category_boxes)
|
97 |
+
category_boxes.change(fn=display_categories, inputs=category_boxes, outputs=category_column)
|
98 |
+
category_boxes.change(fn=display_categories, inputs=category_boxes, outputs=category_column)
|
99 |
+
uploaded_file = gr.File(label="Upload CSV", type="filepath")
|
100 |
+
template_btn = gr.Button("Use Template")
|
101 |
+
gr.Markdown("# NPS Comment Categorization")
|
|
|
|
|
|
|
102 |
classify_btn = gr.Button("Classify Comments")
|
103 |
output = gr.HTML()
|
104 |
+
|
105 |
+
def load_data(file):
|
106 |
+
if file is not None:
|
107 |
+
file.seek(0) # Reset file pointer
|
108 |
+
import io
|
109 |
+
if file.name.endswith('.csv'):
|
110 |
+
file.seek(0) # Reset file pointer
|
111 |
+
custom_df = pd.read_csv(file, encoding='utf-8')
|
112 |
+
custom_df = pd.read_csv(io.StringIO(content))
|
113 |
+
else:
|
114 |
+
return "Error: Uploaded file is not a CSV."
|
115 |
+
if 'customer_comment' not in custom_df.columns:
|
116 |
+
return "Error: Uploaded CSV must contain a column named 'customer_comment'"
|
117 |
+
global df
|
118 |
+
df = custom_df
|
119 |
+
return "Custom CSV loaded successfully!"
|
120 |
+
else:
|
121 |
+
return "No file uploaded."
|
122 |
+
|
123 |
+
uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=output)
|
124 |
+
template_btn.click(fn=lambda: "Using Template Dataset", outputs=output)
|
125 |
+
def use_template():
|
126 |
+
return ["Product Experience", "Customer Support", "Price of Service", "Other"]
|
127 |
+
template_btn.click(fn=use_template, outputs=category_boxes)
|
128 |
+
category_boxes.change(fn=display_categories, inputs=category_boxes, outputs=category_column)
|
129 |
+
classify_btn.click(fn=classify_comments, inputs=category_boxes, outputs=output)
|
130 |
|
131 |
nps.launch()
|