avichr committed on
Commit
eb2e6c5
·
1 Parent(s): 771e9d0

Create HebEMO.py

Browse files
Files changed (1) hide show
  1. HebEMO.py +85 -0
HebEMO.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class HebEMO:
    """Emotion detection for text using one binary classifier per emotion.

    For every emotion name a ``transformers`` "sentiment-analysis" pipeline is
    loaded from ``../hebEMO/<emotion>_classifier`` together with the tokenizer
    stored at ``../heBERT_base_oscar``.

    Attributes:
        device: Value forwarded to ``transformers.pipeline(device=...)``.
        emotions: List of emotion names, in the order they are evaluated.
        hebemo_models: Mapping ``emotion name -> loaded pipeline``.
    """

    # Kept as an immutable tuple so the default can never be mutated by callers.
    DEFAULT_EMOTIONS = ('expectation', 'happy', 'trust', 'fear', 'surprise',
                        'anger', 'sadness', 'disgust')

    def __init__(self, device=0, emotions=None):
        """Load one classification pipeline per emotion.

        Args:
            device: Passed unchanged to ``transformers.pipeline``. In
                transformers an integer >= 0 selects that CUDA device and
                -1 selects the CPU.
            emotions: Iterable of emotion names to load. ``None`` (default)
                loads every name in ``DEFAULT_EMOTIONS``.
        """
        from transformers import pipeline

        self.device = device
        # Copy into a fresh list: avoids sharing a mutable default between
        # instances and keeps the names re-iterable even if a generator
        # was passed in.
        self.emotions = list(self.DEFAULT_EMOTIONS if emotions is None else emotions)
        self.hebemo_models = {}

        for emo in self._progress(self.emotions):
            self.hebemo_models[emo] = pipeline(
                "sentiment-analysis",
                model="../hebEMO/" + emo + '_classifier',
                tokenizer="../heBERT_base_oscar",
                device=self.device,
            )

    @staticmethod
    def _progress(iterable):
        """Wrap *iterable* in a tqdm progress bar when tqdm is installed."""
        try:
            from tqdm import tqdm
        except ImportError:
            # The progress bar is purely cosmetic; run without it.
            return iterable
        return tqdm(iterable)

    @staticmethod
    def _label_to_int(label):
        """Map 'LABEL_0' -> 0 and 'LABEL_1' -> 1; return anything else unchanged."""
        if label == 'LABEL_0':
            return 0
        if label == 'LABEL_1':
            return 1
        return label

    @staticmethod
    def _collect_texts(text, input_path, read_lines):
        """Normalize the ``hebemo`` input arguments into a list of texts."""
        import os

        if text is None and isinstance(input_path, (str, os.PathLike)):
            with open(input_path, encoding='utf8') as source:
                # Drop the line terminator so file rows look exactly like
                # rows produced by ``read_lines=True`` on a string.
                return [line.rstrip('\n') for line in source]

        if text is not None and (input_path is None or input_path is False):
            if isinstance(text, str):
                return text.split('\n') if read_lines else [text]
            if isinstance(text, list):
                return text
            raise ValueError('text should be text or list of text.')

        raise ValueError('you should provide a text string, list of strings or text path.')

    def hebemo(self, text=None, input_path=False, save_results=False, read_lines=False, plot=False):
        """Run every loaded emotion classifier over the given texts.

        Exactly one of ``text`` / ``input_path`` must be supplied.

        Args:
            text: A single string or a list of strings to analyze.
            input_path: Path of a UTF-8 text file holding one instance per line.
            save_results: ``False`` (default) writes nothing. A string is used
                as the output directory; any other value writes to the current
                working directory. The file is named
                ``<timestamp>_heEMOed.csv``.
            read_lines: When ``text`` is a string, split it on newlines and
                treat every line as a separate instance.
            plot: When true, draw a Plutchik wheel for the first instance,
                print that instance and return the figure instead of the
                DataFrame.

        Returns:
            ``pandas.DataFrame`` with the text in column ``0`` followed, for
            each emotion, by ``<emotion>`` (0/1 label) and
            ``confidence_<emotion>`` (score of that label); or the matplotlib
            figure that was drawn when ``plot`` is true.

        Raises:
            ValueError: If the inputs are missing, ambiguous or of the wrong
                type, or if ``plot`` is requested for an empty input.
        """
        import os
        import time

        import pandas as pd
        import torch

        txt = self._collect_texts(text, input_path, read_lines)

        # run hebEMO
        hebEMO_df = pd.DataFrame(txt)
        for emo in self._progress(self.emotions):
            preds = pd.DataFrame(self.hebemo_models[emo](txt)).rename(
                columns={'label': emo, 'score': 'confidence_' + emo})
            # Convert only the label column, so an input text that happens
            # to equal 'LABEL_0'/'LABEL_1' is left untouched.
            if emo in preds.columns:
                preds[emo] = preds[emo].map(self._label_to_int)
            hebEMO_df = hebEMO_df.join(preds)
            del preds
            torch.cuda.empty_cache()

        if save_results is not False:
            file_name = str(int(time.time() * 1e7)) + '_heEMOed.csv'
            if isinstance(save_results, str):
                file_name = os.path.join(save_results, file_name)
            hebEMO_df.to_csv(file_name, encoding='utf8')

        if not plot:
            return hebEMO_df

        if hebEMO_df.empty:
            raise ValueError('nothing to plot: no input texts were provided.')

        # Plotting libraries are only needed (and imported) on this path.
        from pyplutchik import plutchik
        import matplotlib.pyplot as plt

        scores = pd.DataFrame()
        for emo in self.emotions:
            # label 1 with confidence c -> c ; label 0 with confidence c -> 1 - c
            scores[emo] = abs(hebEMO_df[emo] - (1 - hebEMO_df['confidence_' + emo]))
        scores = scores.rename(columns={'happy': 'joy', 'expectation': 'anticipation'})

        # Only the first instance is plotted.
        ax = plutchik(scores.to_dict(orient='records')[0])
        print(hebEMO_df[0][0])
        # Grab the drawn figure before show(); plt.figure() would hand back
        # a new, empty figure.
        fig = ax.get_figure()
        plt.show()
        return fig
85
# Module-level default instance: built with the default arguments, so the
# constructor (which creates one pipeline per emotion) runs at import time.
HebEMO_model = HebEMO()