# AniMAntZeZo's picture
# Upload 20 files
# c7c2507 verified
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot_emotion_confusion_matrix(results_df, emotion_columns):
    """Draw a per-emotion heatmap of prediction outcomes.

    Each row of ``results_df`` contributes, for every label it mentions, to
    one of three counters:

    * "Correctly Identified"   - label is in both the true and predicted sets
    * "Incorrectly Identified" - label is predicted but not in the true set
    * "Undefined"              - label is in the true set but was not predicted

    Labels that are not listed in ``emotion_columns`` are ignored.

    Args:
        results_df: DataFrame with the columns ``'true emotions'`` and
            ``'predict emotions'``. Each cell is expected to be a
            whitespace-separated string of labels; any non-string cell
            (e.g. a missing value) is treated as "no labels".
        emotion_columns: Labels to report, in the order the heatmap rows
            should appear.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If either required column is missing from ``results_df``.
    """
    def labels_of(cell):
        # Non-string cells carry no labels.
        return set(cell.split()) if isinstance(cell, str) else set()

    # One [correct, incorrect, undefined] triple per distinct label; the slot
    # order matches the heatmap column order used below.
    tallies = {emotion: [0, 0, 0] for emotion in emotion_columns}

    # Walk the two columns in lockstep instead of materialising a Series per
    # row with iterrows().
    for true_cell, predicted_cell in zip(
        results_df['true emotions'], results_df['predict emotions']
    ):
        truth = labels_of(true_cell)
        guess = labels_of(predicted_cell)
        # The three sets are disjoint, so each label lands in one slot only.
        outcomes = (truth & guess, guess - truth, truth - guess)
        for slot, labels in enumerate(outcomes):
            for label in labels:
                if label in tallies:
                    tallies[label][slot] += 1

    # Looking counts up per listed label (rather than incrementing per listed
    # label) keeps a label repeated in emotion_columns from being counted twice.
    heatmap_df = pd.DataFrame(
        [tallies[emotion] for emotion in emotion_columns],
        columns=["Correctly Identified", "Incorrectly Identified", "Undefined"],
        index=emotion_columns,
    )
    num_examples = len(results_df)

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="Blues", fmt="d", cbar=False)
    plt.title(f"Emotion Prediction Confusion Matrix (Examples: {num_examples})")
    plt.xlabel("Prediction Status")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
def plot_true_emotion_frequency(results_df, emotion_columns):
    """Draw a single-column heatmap of how often each emotion is a true label.

    Args:
        results_df: DataFrame with a ``'true emotions'`` column. Each cell is
            expected to be a whitespace-separated string of labels; any
            non-string cell (e.g. a missing value) is treated as "no labels".
        emotion_columns: Labels to report, in the order the heatmap rows
            should appear.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If ``results_df`` has no ``'true emotions'`` column.
    """
    # Iterate the column directly; iterrows() would build a Series per row
    # only to read one field from it.
    label_sets = [
        set(cell.split()) if isinstance(cell, str) else set()
        for cell in results_df['true emotions']
    ]

    # Each listed emotion gets its own independent tally (number of rows whose
    # label set contains it), so a label repeated in emotion_columns is
    # reported with its real count instead of a doubled one.
    counts = [
        sum(emotion in labels for labels in label_sets)
        for emotion in emotion_columns
    ]
    heatmap_df = pd.DataFrame({"True Emotion Count": counts}, index=emotion_columns)

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="YlGnBu", fmt="d", cbar=False)
    plt.title(f"True Emotion Frequency (Examples: {len(results_df)})")
    plt.xlabel("True Emotion Count")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
def plot_predicted_emotion_frequency(results_df, emotion_columns):
    """Draw a single-column heatmap of how often each emotion is predicted.

    Args:
        results_df: DataFrame with a ``'predict emotions'`` column. Each cell
            is expected to be a whitespace-separated string of labels; any
            non-string cell (e.g. a missing value) is treated as "no labels".
        emotion_columns: Labels to report, in the order the heatmap rows
            should appear.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If ``results_df`` has no ``'predict emotions'`` column.
    """
    # Iterate the column directly; iterrows() would build a Series per row
    # only to read one field from it.
    label_sets = [
        set(cell.split()) if isinstance(cell, str) else set()
        for cell in results_df['predict emotions']
    ]

    # Each listed emotion gets its own independent tally (number of rows whose
    # label set contains it), so a label repeated in emotion_columns is
    # reported with its real count instead of a doubled one.
    counts = [
        sum(emotion in labels for labels in label_sets)
        for emotion in emotion_columns
    ]
    heatmap_df = pd.DataFrame(
        {"Predicted Emotion Count": counts}, index=emotion_columns
    )

    plt.figure(figsize=(10, 12))
    sns.heatmap(heatmap_df, annot=True, cmap="YlOrRd", fmt="d", cbar=False)
    plt.title(f"Predicted Emotion Frequency (Examples: {len(results_df)})")
    plt.xlabel("Predicted Emotion Count")
    plt.ylabel("Emotion")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
# Labels to report; their order here is the row order of every heatmap.
emotion_columns = [
    "admiration",
    "amusement",
    "anger",
    "annoyance",
    "approval",
    "caring",
    "confusion",
    "curiosity",
    "desire",
    "disappointment",
    "disapproval",
    "disgust",
    "embarrassment",
    "excitement",
    "fear",
    "gratitude",
    "grief",
    "joy",
    "love",
    "nervousness",
    "optimism",
    "pride",
    "realization",
    "relief",
    "remorse",
    "sadness",
    "surprise",
    "neutral",
]

# Load the results table that holds the 'true emotions' / 'predict emotions'
# columns read by the plotting functions.
csv_path = "RuBert-tiny2-EmotionsDetected/Dstasets/Emotions_detected.csv"
results_df = pd.read_csv(csv_path)

# Render the three figures: label frequencies first, then the outcome matrix.
plot_true_emotion_frequency(results_df, emotion_columns)
plot_predicted_emotion_frequency(results_df, emotion_columns)
plot_emotion_confusion_matrix(results_df, emotion_columns)