import gradio as gr
import pandas as pd
from transformers import pipeline

# Model pipelines are created once at import time (weights are fetched on
# first use of each checkpoint, so startup can be slow).
# Sentiment model: DistilBERT fine-tuned on SST-2.
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
# Instruction-following model used to pick a category for each comment.
generator = pipeline("text2text-generation", model="google/flan-t5-base")

# Column that every uploaded CSV must provide.
COMMENT_COLUMN = "customer_comment"

# Currently loaded comments; only replaced by load_data() after a CSV has been
# read AND validated, so it never holds a frame lacking COMMENT_COLUMN.
df = pd.DataFrame()


def classify_comments(categories):
    """Label every loaded comment with a sentiment and one of *categories*.

    Args:
        categories: list of category names the generator should choose from.

    Returns:
        An HTML table (comment, sentiment, category) on success, or a plain
        status message when there is no data, no categories, or the loaded
        data lacks the comment column.

    Side effects:
        Adds/overwrites the ``comment_sentiment`` and ``comment_category``
        columns of the module-level ``df`` in place.
    """
    if df.empty:
        return "No data loaded."
    if not categories:
        return "No categories defined."
    if COMMENT_COLUMN not in df.columns:
        return f"Error: loaded data has no '{COMMENT_COLUMN}' column."

    # Empty CSV cells arrive as NaN floats; the pipelines need strings.
    comments = df[COMMENT_COLUMN].fillna("").astype(str).tolist()
    # The allowed categories are the same for every comment: build the list once.
    category_str = ", ".join(categories)

    sentiments = []
    comment_categories = []
    for comment in comments:
        # truncation=True keeps very long comments within the model's input size
        # instead of raising an error mid-run.
        sentiments.append(classifier(comment, truncation=True)[0]["label"])
        prompt = (
            f"What category best describes this comment? '{comment}' "
            f"Please answer using only the name of the category: {category_str}."
        )
        # NOTE(review): the generated text is not checked against `categories`;
        # the model may answer with something outside the list.
        comment_categories.append(generator(prompt, max_length=30)[0]["generated_text"])

    df["comment_sentiment"] = sentiments
    df["comment_category"] = comment_categories
    return df[[COMMENT_COLUMN, "comment_sentiment", "comment_category"]].to_html(index=False)


def load_data(file):
    """Read an uploaded CSV into the module-level ``df``.

    Args:
        file: the value delivered by the Gradio File component, either an
            object exposing the temp-file path as ``.name`` or a plain path
            string. ``None`` means nothing was uploaded.

    Returns:
        A status message for display.

    The global ``df`` is replaced only after the CSV has been read and has
    the required column, so a rejected upload keeps the previous data.
    """
    global df
    if file is None:
        return "No file uploaded."
    path = getattr(file, "name", file)
    try:
        new_df = pd.read_csv(path)
    except Exception as e:  # UI boundary: report any read/parse failure to the user
        return f"Error loading CSV: {e}"
    if COMMENT_COLUMN not in new_df.columns:
        return "Error: Uploaded CSV must contain a column named 'customer_comment'"
    df = new_df
    return "Custom CSV loaded successfully!"
# Function to add a new category
def add_category(categories, new_category):
    """Add *new_category* (whitespace-stripped) to the category list if new.

    Args:
        categories: current list of category names (Gradio session state).
        new_category: raw text from the input box; may be empty or None.

    Returns:
        A tuple of (new category list, update that clears the input box).
        A fresh list is returned rather than mutating the state in place.
    """
    cleaned = (new_category or "").strip()
    updated = list(categories or [])
    # Compare the stripped text so " Price" and "Price" are not stored twice.
    if cleaned and cleaned not in updated:
        updated.append(cleaned)
    return updated, gr.update(value="")  # Clear the input box after adding


# Function to remove a category
def remove_category(categories, category_to_remove):
    """Return *categories* without any entry equal to *category_to_remove*."""
    return [cat for cat in (categories or []) if cat != category_to_remove]


# Function to display categories
def display_categories(categories):
    """Render *categories* as a Markdown bullet list (or a placeholder if empty).

    Components cannot be created from inside an event handler, so the list is
    shown as Markdown text and removal goes through a dropdown instead.
    """
    if not categories:
        return "No categories defined yet."
    return "\n".join(f"- {category}" for category in categories)


def _refresh_category_views(categories):
    """Return (Markdown list, remove-dropdown update) for the current categories."""
    return display_categories(categories), gr.update(choices=list(categories or []), value=None)


# Gradio Interface
with gr.Blocks() as nps:
    gr.Markdown("# NPS Comment Categorization")
    categories = gr.State([])  # Per-session list of category names

    with gr.Row():
        category_input = gr.Textbox(label="New Category", placeholder="Enter category name")
        add_category_btn = gr.Button("Add Category")

    category_list = gr.Markdown(display_categories([]))
    with gr.Row():
        category_to_remove = gr.Dropdown(choices=[], label="Category to remove")
        remove_category_btn = gr.Button("Remove")

    # After each state change, re-render the list and the dropdown choices.
    add_category_btn.click(
        fn=add_category,
        inputs=[categories, category_input],
        outputs=[categories, category_input],
        queue=False,
    ).then(
        fn=_refresh_category_views,
        inputs=categories,
        outputs=[category_list, category_to_remove],
        queue=False,
    )
    remove_category_btn.click(
        fn=remove_category,
        inputs=[categories, category_to_remove],
        outputs=categories,
        queue=False,
    ).then(
        fn=_refresh_category_views,
        inputs=categories,
        outputs=[category_list, category_to_remove],
        queue=False,
    )

    uploaded_file = gr.File(label="Upload CSV", type="file")
    upload_output = gr.HTML()
    uploaded_file.change(fn=load_data, inputs=uploaded_file, outputs=upload_output, queue=False)

    classify_btn = gr.Button("Classify Comments")
    output = gr.HTML()
    # Not forced off the queue: this runs two models per comment and can be slow.
    classify_btn.click(fn=classify_comments, inputs=categories, outputs=output)

nps.launch()