File size: 3,183 Bytes
eb2e6c5
a3b44bc
eb2e6c5
 
 
 
 
 
 
 
 
 
 
 
6dd4d02
 
a3b44bc
eb2e6c5
 
 
 
 
 
 
 
 
af21e08
 
c8c9e7c
eb2e6c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8c9e7c
 
eb2e6c5
 
dbcc766
eb2e6c5
 
e021892
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class HebEMO:
    """Run a set of per-emotion text classifiers over Hebrew text.

    One ``transformers`` "sentiment-analysis" pipeline is created per emotion,
    using the model ``avichr/hebEMO_<emotion>`` and the ``avichr/heBERT``
    tokenizer. The pipelines are kept in ``self.hebemo_models`` keyed by
    emotion name.
    """

    # Pipeline class labels and the integer indicator stored in the results.
    _LABEL_TO_INT = {'LABEL_0': 0, 'LABEL_1': 1}

    def __init__(self, device=-1, emotions=('anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger',
                                            'sadness', 'disgust')):
        """Load one pipeline per requested emotion.

        Args:
            device: -1 to run on CPU, otherwise the device ID passed to the
                pipelines.
            emotions: names of the emotions to load; each name is appended to
                ``avichr/hebEMO_`` to form the model identifier.
        """
        from transformers import pipeline

        self.device = device
        # Copy into a fresh list so instances never share (or mutate) the default.
        self.emotions = list(emotions)
        self.hebemo_models = {}

        for emo in self._progress(self.emotions):
            self.hebemo_models[emo] = pipeline(
                "sentiment-analysis",
                model="avichr/hebEMO_" + emo,
                tokenizer="avichr/heBERT",
                device=self.device,  # -1 run on CPU, else - device ID
            )

    @staticmethod
    def _progress(iterable):
        """Wrap *iterable* in a tqdm progress bar when tqdm is installed."""
        try:
            from tqdm import tqdm
        except ImportError:
            # The progress bar is cosmetic; iterate plainly without it.
            return iterable
        return tqdm(iterable)

    def hebemo(self, text=None, input_path=False, save_results=False, read_lines=False, plot=False):
        """Classify texts with every loaded emotion model.

        Exactly one of ``text`` / ``input_path`` must be supplied.

        Args:
            text: a single string or a list of strings to analyze.
            input_path: path to a UTF-8 text file; each line is one instance.
            save_results: ``False``/``None`` to skip saving. A string is used
                as the output directory; any other value saves to the current
                working directory. The file is named
                ``<timestamp>_heEMOed.csv``.
            read_lines: when ``text`` is a string, split it on ``'\\n'`` and
                treat each piece as a separate instance.
            plot: when true, plot the first analyzed text and return a tuple
                instead of the DataFrame.

        Returns:
            With ``plot`` false: a pandas DataFrame whose column ``0`` holds
            the texts, followed by ``<emotion>`` (0 for ``LABEL_0``, 1 for
            ``LABEL_1``) and ``confidence_<emotion>`` (the pipeline score) for
            every emotion.
            With ``plot`` true: ``(first_text, ax)`` where ``ax`` is whatever
            the plotting function returned.

        Raises:
            ValueError: if neither or both inputs are given, or ``text`` is
                neither a string nor a list.
        """
        import os
        import time

        import pandas as pd
        import torch

        if text is None and isinstance(input_path, str):
            # One instance per line; drop the line terminator so it does not
            # end up inside the analyzed text or the saved CSV.
            with open(input_path, encoding='utf8') as p:
                txt = [line.rstrip('\n') for line in p]
        elif text is not None and (input_path is None or input_path is False):
            if isinstance(text, str):
                txt = text.split('\n') if read_lines else [text]
            elif isinstance(text, list):
                txt = text
            else:
                raise ValueError('text should be text or list of text.')
        else:
            raise ValueError('you should provide a text string, list of strings or text path.')

        # run hebEMO
        hebEMO_df = pd.DataFrame(txt)
        for emo in self._progress(self.emotions):
            scored = pd.DataFrame(self.hebemo_models[emo](txt)).rename(
                columns={'label': emo, 'score': 'confidence_' + emo})
            if emo in scored.columns:
                # Convert only the label column, so an input text that happens
                # to read 'LABEL_0'/'LABEL_1' is left untouched.
                scored[emo] = scored[emo].map(
                    lambda label: self._LABEL_TO_INT.get(label, label))
            hebEMO_df = hebEMO_df.join(scored)
            del scored
            # Free cached GPU memory before the next model runs.
            torch.cuda.empty_cache()

        if save_results is not False and save_results is not None:
            gen_name = str(int(time.time() * 1e7))
            file_name = gen_name + '_heEMOed.csv'
            if isinstance(save_results, str):
                file_name = os.path.join(save_results, file_name)
            hebEMO_df.to_csv(file_name, encoding='utf8')

        if not plot:
            return hebEMO_df

        # Plotting packages are imported lazily so they are only required
        # when a plot is actually requested.
        import matplotlib.pyplot as plt
        from pyplutchik import plutchik
        from spider_plot import spider_plot

        # Fold label and confidence into one value per emotion:
        # label 1 -> confidence, label 0 -> 1 - confidence.
        hebEMO = pd.DataFrame()
        for emo in self.emotions:
            hebEMO[emo] = abs(hebEMO_df[emo] - (1 - hebEMO_df['confidence_' + emo]))

        i = 0  # only the first text is plotted
        try:
            ax = plutchik(hebEMO.to_dict(orient='records')[i])
        except Exception:
            # Best-effort fallback when the Plutchik wheel cannot be drawn.
            ax = spider_plot(hebEMO)
        print(hebEMO_df[0][i])
        plt.show()
        return (hebEMO_df[0][i], ax)
# HebEMO_model = HebEMO()