File size: 5,510 Bytes
cbcd78b
 
 
ee68faf
cbcd78b
a32f02f
4530b74
52d8051
 
cbcd78b
a32f02f
4530b74
 
 
 
 
 
 
3fac692
3ae1eb6
cbcd78b
 
0180738
 
cbcd78b
3fac692
 
cbcd78b
 
0a5100e
9d03f28
 
 
3fac692
 
 
 
 
 
cbcd78b
 
 
b40dd56
 
 
9a2c26b
9d03f28
9a2c26b
36157a2
 
9d32e7a
 
 
 
 
 
 
36157a2
 
 
 
 
 
 
9d32e7a
 
 
b40dd56
 
 
0ad28ff
b40dd56
 
 
 
 
 
 
0ad28ff
 
 
 
 
 
 
b452d2a
9d03f28
b40dd56
0ad28ff
9d32e7a
7d44e84
ea5a489
cbcd78b
 
 
 
ea5a489
 
63651ec
848f580
 
1f0cb35
7d44e84
848f580
 
 
ea5a489
 
 
 
 
 
 
 
 
 
9a2c26b
0ad28ff
9a2c26b
0ad28ff
ecef70e
cbcd78b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os

import gradio as gr
from transformers import pipeline
import pandas as pd
import spaces
from datasets import load_dataset
from huggingface_hub import login

# Load the template dataset and expose it as the module-level DataFrame
# that the classification and upload handlers read and replace.
ds = load_dataset('ZennyKenny/demo_customer_nps')
df = pd.DataFrame(ds['train'])

# Log in using the API key stored as an environment variable.
# Skip the call when the variable is unset: per the huggingface_hub docs,
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when the app starts non-interactively.
hf_api_key = os.getenv("API_KEY")
if hf_api_key:
    login(token=hf_api_key)

# Initialize model pipelines: sentiment classification and free-text
# category generation.
classifier = pipeline("text-classification", model="distilbert/distilbert-base-uncased-finetuned-sst-2-english")
generator = pipeline("text2text-generation", model="google/flan-t5-base")

# Function to classify customer comments
# https://huggingface.co/docs/hub/spaces-zerogpu
@spaces.GPU
def classify_comments(categories=None):
    """Label every comment in the module-level ``df`` with sentiment and category.

    Args:
        categories: Iterable of category names to offer to the generator
            model. The UI wires the ``category_boxes`` state (a list of
            strings) into this parameter. ``None`` or an empty iterable
            yields a prompt with no category names.

    Returns:
        An HTML table (string) with the ``customer_comment``,
        ``comment_sentiment`` and ``comment_category`` columns.

    Side effects:
        Adds/overwrites the ``comment_sentiment`` and ``comment_category``
        columns on the module-level ``df``.
    """
    # The category list is the same for every comment, so build it once.
    category_names = [str(name).strip() for name in (categories or [])]
    category_str = ', '.join(name for name in category_names if name)

    sentiments = []
    predicted_categories = []
    # Uploaded CSVs may contain empty or non-text cells; the pipelines
    # expect strings, so coerce before calling them.
    for comment in df['customer_comment'].fillna('').astype(str):
        sentiments.append(classifier(comment)[0]['label'])
        prompt = f"What category best describes this comment? '{comment}' Please answer using only the name of the category: {category_str}."
        predicted_categories.append(generator(prompt, max_length=30)[0]['generated_text'])

    df['comment_sentiment'] = sentiments
    df['comment_category'] = predicted_categories
    return df[['customer_comment', 'comment_sentiment', 'comment_category']].to_html(index=False)

# Gradio Interface
with gr.Blocks() as nps:
    gr.Markdown("# NPS Comment Categorization")

    # Per-session list of category names (plain strings). Despite the name,
    # it holds the names themselves, not input components.
    category_boxes = gr.State([])

    def add_category(category_list, new_category):
        """Return the category list with ``new_category`` appended.

        Blank names and names already in the list are ignored, in which
        case the list is returned unchanged. The input list is never
        mutated; a new list is returned when something is added.
        """
        name = (new_category or "").strip()
        if not name or name in category_list:
            return category_list
        return [*category_list, name]

    def remove_category(category, category_list):
        """Return a copy of ``category_list`` without ``category``.

        A name that is not in the list is ignored rather than raising.
        """
        return [existing for existing in category_list if existing != category]

    def use_template():
        """Return the default category names for the template dataset."""
        return ["Product Experience", "Customer Support", "Price of Service", "Other"]

    def load_data(file):
        """Replace the module-level ``df`` with the uploaded CSV.

        Args:
            file: Path of the uploaded file. The upload component uses
                ``type="filepath"``, so this is a string path (or ``None``
                when the upload is cleared). Objects exposing a ``name``
                attribute holding the path are accepted as well.

        Returns:
            A status message describing success or the reason for failure.
        """
        global df
        if file is None:
            return "No file uploaded."

        path = file if isinstance(file, str) else getattr(file, "name", "")
        if not str(path).lower().endswith('.csv'):
            return "Error: Uploaded file is not a CSV."

        try:
            custom_df = pd.read_csv(path, encoding='utf-8')
        except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError, OSError) as exc:
            return f"Error: Could not read the uploaded CSV ({type(exc).__name__})."

        if 'customer_comment' not in custom_df.columns:
            return "Error: Uploaded CSV must contain a column named 'customer_comment'"

        df = custom_df
        return "Custom CSV loaded successfully!"

    with gr.Row():
        category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
        add_category_btn = gr.Button("Add Category")

    # Components cannot be created from an ordinary event handler; the
    # render decorator re-runs this function whenever the state changes
    # and rebuilds the rows (and their listeners) in place.
    @gr.render(inputs=category_boxes)
    def display_categories(categories):
        """Render one row per category with an "X" button that removes it."""
        for i, cat in enumerate(categories):
            with gr.Row():
                gr.Markdown(f"- {cat}")
                remove_btn = gr.Button("X", elem_id=f"remove_{i}", interactive=True, scale=0)
                # Bind the category as a default argument so each button
                # removes its own row; the current list arrives via inputs.
                remove_btn.click(
                    fn=lambda current, target=cat: remove_category(target, current),
                    inputs=category_boxes,
                    outputs=category_boxes,
                )

    uploaded_file = gr.File(label="Upload CSV", type="filepath")
    template_btn = gr.Button("Use Template")
    classify_btn = gr.Button("Classify Comments")
    output = gr.HTML()

    add_category_btn.click(fn=add_category, inputs=[category_boxes, category_input], outputs=category_boxes)
    uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=output)
    # NOTE(review): this only shows a message and loads the default
    # categories; it does not restore the template dataset after a custom
    # CSV has replaced df. Confirm whether that is intended.
    template_btn.click(fn=lambda: "Using Template Dataset", outputs=output)
    template_btn.click(fn=use_template, outputs=category_boxes)
    classify_btn.click(fn=classify_comments, inputs=category_boxes, outputs=output)

nps.launch()