Rezky Mulia Kam committed on
Commit
08e222d
·
verified ·
1 Parent(s): 6071c10

Delete _multiclass_confusion_matrix.py

Browse files
Files changed (1) hide show
  1. _multiclass_confusion_matrix.py +0 -161
_multiclass_confusion_matrix.py DELETED
@@ -1,161 +0,0 @@
1
- import pandas as pd
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
- from sklearn.metrics import confusion_matrix
5
- import seaborn as sns
6
- import matplotlib
7
- matplotlib.use('Qt5Agg')
8
- import matplotlib.pyplot as plt
9
- from sklearn.model_selection import train_test_split
10
- import numpy as np
11
- import os
12
- os.environ['QT_QPA_PLATFORM'] = 'xcb'
13
-
14
# Emotion class index -> display name, and the inverse used to convert
# text labels found in the CSV into integer class indices.
label_map = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
reverse_label_map = {v: k for k, v in label_map.items()}

# Load the dataset
df = pd.read_csv('./dataset/emotions.csv')

# Both columns are read further down; stop early with a clear message instead
# of a KeyError. SystemExit with a string prints it to stderr and exits with
# status 1, and does not rely on the `exit` helper injected by the site module.
for required_column in ('label', 'text'):
    if required_column not in df.columns:
        raise SystemExit(f"Error: '{required_column}' column is missing from the dataset.")

# Convert text labels to numeric if they're not already numeric. Testing for
# "not numeric" (rather than dtype == 'object') also covers pandas' dedicated
# string dtypes.
if not pd.api.types.is_numeric_dtype(df['label']):
    df['label'] = df['label'].map(reverse_label_map)

# Unknown label names map to NaN; refuse to continue with partial labels.
if df['label'].isnull().any():
    raise SystemExit("Error: Some labels could not be mapped properly.")

# Sample a smaller subset for faster debugging. Clamp to the dataset size so
# a CSV with fewer rows than requested does not make DataFrame.sample raise.
sample_size = min(20000, len(df))  # Adjust the cap as needed
df_sampled = df.sample(n=sample_size, random_state=42)

# Split the sampled dataset (80% train / 20% validation, fixed seed)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df_sampled['text'].tolist(),
    df_sampled['label'].tolist(),
    test_size=0.2,
    random_state=42
)

# Load the fine-tuned 6-class checkpoint and its tokenizer from disk
model_6_path = "./models/stardust_6"
tokenizer = AutoTokenizer.from_pretrained(model_6_path)
model = AutoModelForSequenceClassification.from_pretrained(model_6_path, num_labels=6)
model.eval()  # Set model to evaluation mode
52
-
53
- # Define a function for tokenization and encoding
54
def tokenize_and_encode(texts, labels):
    """Tokenize *texts* and attach *labels* to the result.

    The module-level ``tokenizer`` is called once on the whole list with
    padding and truncation enabled and PyTorch tensors requested. The labels
    are converted with ``torch.tensor`` and stored under the ``'labels'`` key
    of the tokenizer's output, which is then returned.
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    label_tensor = torch.tensor(labels)
    encoded['labels'] = label_tensor
    return encoded
58
-
59
# Pick the GPU when CUDA is available, otherwise stay on the CPU, and place
# the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Tokenize both splits; each result also carries a 'labels' tensor.
# NOTE(review): train_dataset is not referenced anywhere else in the visible
# script - confirm whether tokenizing the training split is still needed.
train_dataset = tokenize_and_encode(train_texts, train_labels)
val_dataset = tokenize_and_encode(val_texts, val_labels)

# Separate the validation labels from the model inputs and move both to the
# chosen device. val_labels is rebound here from a Python list to a tensor.
val_inputs = {
    field: tensor.to(device)
    for field, tensor in val_dataset.items()
    if field != 'labels'
}
val_labels = val_dataset['labels'].to(device)
70
-
71
def _collect_logits(model, val_inputs, batch_size=None):
    """Run *model* without gradients and return its logits as a NumPy array."""
    with torch.no_grad():
        if batch_size is None:
            # Original behaviour: one forward pass over everything.
            return model(**val_inputs).logits.cpu().numpy()
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer or None")
        # Assumes every tensor in val_inputs has the sample axis first
        # (true for tokenizer output with return_tensors="pt") - TODO confirm
        # for other callers.
        n_rows = next(iter(val_inputs.values())).shape[0]
        pieces = []
        for start in range(0, n_rows, batch_size):
            batch = {
                name: tensor[start:start + batch_size]
                for name, tensor in val_inputs.items()
            }
            pieces.append(model(**batch).logits.cpu())
    return torch.cat(pieces).numpy()


def _stable_softmax(logits):
    """Row-wise softmax that subtracts the row maximum before exponentiating."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=1, keepdims=True)


def plot_classification_analysis(val_labels, val_inputs, model, label_map, batch_size=None):
    """Plot a confusion matrix and a raw-logit heatmap for a classifier.

    Args:
        val_labels: torch tensor of ground-truth labels. A tensor with more
            than one dimension is treated as one-hot and reduced with argmax
            over the last dimension.
        val_inputs: mapping of keyword name -> tensor, passed to the model as
            ``model(**val_inputs)``.
        model: callable whose result exposes a ``.logits`` tensor.
        label_map: dict mapping class index -> display name. Its keys fix the
            rows/columns of the confusion matrix and its values label the axes.
        batch_size: optional positive int. When given, the forward pass is
            done in chunks of this many samples to bound peak memory. ``None``
            (default) keeps the single full pass.

    Returns:
        ``(fig, metrics)`` where ``fig`` is the two-panel matplotlib figure and
        ``metrics`` is a dict with ``'confusion_matrix'`` and
        ``'raw_logits_stats'`` (per-class ``mean``/``std``/``min``/``max`` of
        the logits).

    Raises:
        ValueError: if ``batch_size`` is given and smaller than 1.
    """
    # Convert labels if they're one-hot encoded
    if len(val_labels.shape) > 1:
        true_labels = val_labels.argmax(dim=-1).cpu().numpy()
    else:
        true_labels = val_labels.cpu().numpy()

    logits = _collect_logits(model, val_inputs, batch_size)

    # A naive exp(x) / sum(exp(x)) overflows for large logits and yields NaN
    # rows, which would silently corrupt the argmax below.
    probabilities = _stable_softmax(logits)
    predictions_softmax = np.argmax(probabilities, axis=-1)

    # Class ids and display names, kept in the same (dict) order
    class_ids = list(label_map.keys())
    label_map_list = list(label_map.values())

    # Number of samples actually evaluated; used in the titles so the plots
    # describe the data they show rather than a module-level setting.
    n_samples = logits.shape[0]

    # Create figure with two subplots
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # First subplot: Confusion Matrix. Passing labels= pins the matrix to one
    # row/column per entry of label_map even when a class never occurs, so
    # the tick labels always line up.
    cm_softmax = confusion_matrix(true_labels, predictions_softmax, labels=class_ids)
    sns.heatmap(
        cm_softmax,
        annot=True,
        fmt="d",
        cmap="Oranges",
        xticklabels=label_map_list,
        yticklabels=label_map_list,
        ax=axes[0],
        square=True
    )
    axes[0].set_xlabel("Prediction")
    axes[0].set_ylabel("Truth")
    axes[0].set_title(f"Softmax [{n_samples}]")

    # Rotate x-axis labels for better readability
    axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha='right')
    axes[0].set_yticklabels(axes[0].get_yticklabels(), rotation=0)

    # Second subplot: Raw Logits Heatmap (one row per evaluated sample)
    sns.heatmap(
        logits,
        annot=False,
        cmap="Oranges",
        cbar=True,
        xticklabels=label_map_list,
        yticklabels=False,
        ax=axes[1]
    )
    axes[1].set_xlabel("Classes")
    axes[1].set_ylabel("Samples")
    axes[1].set_title(f"Logits Distribution [{n_samples}]")

    # Rotate x-axis labels for better readability
    axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=45, ha='right')

    # Add color bar labels; each heatmap's first collection owns its colorbar
    for ax, cbar_label in zip(axes, ['Number of Samples', 'Logit Value']):
        ax.collections[0].colorbar.set_label(cbar_label)

    plt.tight_layout()

    # Calculate and return additional metrics (statistics are per class column)
    metrics = {
        'confusion_matrix': cm_softmax,
        'raw_logits_stats': {
            'mean': np.mean(logits, axis=0),
            'std': np.std(logits, axis=0),
            'min': np.min(logits, axis=0),
            'max': np.max(logits, axis=0),
        },
    }

    return fig, metrics
149
-
150
# Build the confusion-matrix / logits figure for the validation split and
# keep the returned metrics, then display the figure.
fig, metrics = plot_classification_analysis(val_labels, val_inputs, model, label_map)
plt.show()
158
-
159
-
160
-
161
-