Spaces:
Running
Running
Commit
·
18810b9
1
Parent(s):
005ef83
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
import altair as alt
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from datetime import date, timedelta
|
12 |
+
|
13 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification
|
14 |
+
|
15 |
+
# Tokenizer/model pairs keyed by checkpoint path, so the checkpoint is read
# from disk once instead of on every button click.
_INFERENCE_CACHE = {}


def _load_inference_model(modelpath):
    """Return the ``(tokenizer, model)`` pair for *modelpath*, loading it once."""
    if modelpath not in _INFERENCE_CACHE:
        _INFERENCE_CACHE[modelpath] = (
            AutoTokenizer.from_pretrained(modelpath),
            AutoModelForSequenceClassification.from_pretrained(modelpath),
        )
    return _INFERENCE_CACHE[modelpath]


def inference_sentence(text):
    """Predict the emotion label of a single sentence.

    Args:
        text: The sentence to classify.

    Returns:
        A string made of the prefix "Predicted emotion:", a newline, and the
        label that ``model.config.id2label`` gives for the highest-scoring
        class.
    """
    tokenizer, model = _load_inference_model(inference_modelpath)
    # truncation=True cuts over-long input down to the tokenizer's maximum
    # length instead of passing an oversized sequence to the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()
    output = model.config.id2label[predicted_class_id]
    return "Predicted emotion:\n" + output
|
25 |
+
|
26 |
+
def freq(file_output):
    """Build a bar chart of how often each emotion label was predicted.

    Args:
        file_output: Path to a comma-separated predictions file. The first
            line is skipped as a header; on every other non-blank line the
            second field is read as the label id ('0' to '5').

    Returns:
        An interactive Altair bar chart with one bar per emotion category.
        Categories that never occur in the file get a frequency of 0.
    """
    from collections import Counter

    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    colours = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']

    with open(file_output, 'r') as f:
        # splitlines() copes with a missing trailing newline and with "\r\n",
        # so the last record is never silently dropped.
        rows = f.read().splitlines()[1:]

    # Counter returns 0 for label ids that are absent from the file.
    counts = Counter(row.split(",")[1] for row in rows if row.strip())

    simple = pd.DataFrame({
        'Emotion category': emotions,
        'Frequency': [counts[str(label_id)] for label_id in range(len(emotions))]})

    n = int(simple['Frequency'].max())

    plot = alt.Chart(simple).mark_bar().encode(
        x=alt.X("Emotion category", sort=emotions),
        # Upper bound of the y axis: the highest count rounded up to a multiple of 10.
        y=alt.Y("Frequency", axis=alt.Axis(grid=False), scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])),
        color=alt.Color("Emotion category", scale=alt.Scale(domain=emotions, range=colours), legend=None),
        tooltip=['Emotion category', 'Frequency']).properties(
            width=600).configure_axis(
                grid=False).interactive()
    return plot
|
55 |
+
|
56 |
+
def _parse_yyyymmdd(stamp):
    """Convert a string whose first eight characters are YYYYMMDD to a ``date``."""
    return date(int(stamp[:4]), int(stamp[4:6]), int(stamp[6:8]))


def dist(file_output):
    """Build a line chart of predicted-emotion frequencies per day.

    Args:
        file_output: Path to a comma-separated predictions file. The first
            line is skipped as a header; on every other non-blank line the
            first field starts with a YYYYMMDD date and the second field is
            the label id ('0' to '5').

    Returns:
        An interactive Altair chart with one line per emotion category,
        covering every calendar day from the earliest to the latest date in
        the file. Days without any rows are plotted with a frequency of 0.

    Raises:
        ValueError: If the file contains no data rows.
        KeyError: If a row carries a label id outside '0' to '5'.
    """
    from collections import Counter, defaultdict

    emotions = ['neutral', 'anger', 'fear', 'joy', 'love', 'sadness']
    colours = ['#999999', '#b22222', '#663399', '#ffcc00', '#db7093', '#6495ed']
    mapping_dict = {'0': 'neutral', '1': 'anger', '2': 'fear', '3': 'joy', '4': 'love', '5': 'sadness'}

    with open(file_output, 'r') as f:
        # splitlines() copes with a missing trailing newline and with "\r\n".
        rows = [row.split(",") for row in f.read().splitlines()[1:] if row.strip()]
    if not rows:
        raise ValueError("no prediction rows found in " + str(file_output))

    # day -> emotion name -> number of predictions
    freq_dict = defaultdict(Counter)
    for row in rows:
        freq_dict[_parse_yyyymmdd(row[0])][mapping_dict[row[1]]] += 1

    # Use min/max rather than first/last row so unsorted files still yield
    # the full date span.
    start_date = min(freq_dict)
    end_date = max(freq_dict)
    delta = end_date - start_date  # returns timedelta
    date_range = [start_date + timedelta(days=i) for i in range(delta.days + 1)]

    dates = []
    frequency = []
    categories = []
    for day in date_range:
        # .get avoids inserting empty entries for days without predictions.
        day_counts = freq_dict.get(day, {})
        for emotion in emotions:
            dates.append(str(day))
            frequency.append(day_counts.get(emotion, 0))
            categories.append(emotion)

    data = pd.DataFrame({
        'Date': dates,
        'Frequency': frequency,
        'Emotion category': categories})

    n = int(data['Frequency'].max())

    highlight = alt.selection(
        type='single', on='mouseover', fields=["Emotion category"], nearest=True)

    base = alt.Chart(data).encode(
        x="Date:T",
        # Upper bound of the y axis: the highest count rounded up to a multiple of 10.
        y=alt.Y("Frequency", scale=alt.Scale(domain=[0, (n + 9) // 10 * 10])),
        color=alt.Color("Emotion category", scale=alt.Scale(domain=emotions, range=colours), legend=alt.Legend(orient='bottom', direction='horizontal')))

    # Fully transparent points: they carry the tooltips and the hover selection.
    points = base.mark_circle().encode(
        opacity=alt.value(0),
        tooltip=[
            alt.Tooltip('Emotion category', title='Emotion category'),
            alt.Tooltip('Date:T', title='Date'),
            alt.Tooltip('Frequency', title='Frequency')
        ]).add_selection(highlight)

    # The line of the hovered emotion is drawn thicker (3 instead of 1).
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3)))

    plot = (points + lines).properties(width=600, height=350).interactive()
    return plot
|
121 |
+
|
122 |
+
# Maps each dropdown label to the suffix used in the names of its
# precomputed files under output/.
_SHOWCASE_SUFFIXES = {
    "The Voice of Holland": "tvoh",
    "Floodings": "floodings",
    "COVID-19": "covid",
    "Childcare Benefits": "toeslagen",
}


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file."""
    # NOTE: pickle.load can execute arbitrary code; only use it on files that
    # ship with the app, never on user-supplied data.
    with open(path, 'rb') as f:
        return pickle.load(f)


def showcase(dataset):
    """Collect the precomputed outputs for one example dataset.

    Args:
        dataset: One of the dataset labels offered in the dropdown.

    Returns:
        A 6-tuple of ``gr.update`` objects that make visible, in order: the
        output heading, the predictions file, the frequency bar plot, the
        distribution-over-time plot, the peaks plot and the topics plot.

    Raises:
        ValueError: If *dataset* is not a known dataset label (for example
            when nothing has been selected yet).
    """
    suffix = _SHOWCASE_SUFFIXES.get(dataset) if isinstance(dataset, str) else None
    if suffix is None:
        raise ValueError("Please select one of: " + ", ".join(_SHOWCASE_SUFFIXES))

    # predictions file
    file_output = "output/predictions_" + suffix + ".txt"
    # freq bar plot
    freq_output = freq(file_output)
    # dist plot
    dist_output = dist(file_output)
    # peaks
    peaks_output = _load_pickle("output/peaks_" + suffix + ".p")
    # topics
    topics_output = _load_pickle("output/topics_" + suffix + ".p")

    return (
        gr.update(visible=True),
        gr.update(value=file_output, visible=True),
        gr.update(value=freq_output, visible=True),
        gr.update(value=dist_output, visible=True),
        gr.update(value=peaks_output, visible=True),
        gr.update(value=topics_output, visible=True),
    )
|
155 |
+
|
156 |
+
|
157 |
+
# Checkpoint directory read by inference_sentence().
inference_modelpath = "model/checkpoint-128"

# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a whitespace-stripped copy of this file; confirm against the original.
with gr.Blocks() as demo:
    with gr.Column(scale=1, min_width=50):
        gr.Markdown("""
        """)
    with gr.Column(scale=5):
        gr.Markdown("""
        <div style="text-align: center"><h1>EmotioNL: A framework for Dutch emotion detection</h1></div>

        <div style="display: block;margin-left: auto;margin-right: auto;width: 60%;"><img alt="EmotioNL logo" src="https://users.ugent.be/~lundbruy/EmotioNL.png" width="100%"></div>

        <div style="display: block;margin-left: auto;margin-right: auto;width: 75%;">This demo was made to demonstrate the EmotioNL model, a transformer-based classification model that analyses emotions in Dutch texts. The model uses <a href="https://github.com/iPieter/RobBERT">RobBERT</a>, which was further fine-tuned on the <a href="https://lt3.ugent.be/resources/emotionl/">EmotioNL dataset</a>. The resulting model is a classifier that, given a sentence, predicts one of the following emotion categories: <i>anger</i>, <i>fear</i>, <i>joy</i>, <i>love</i>, <i>sadness</i> or <i>neutral</i>. The demo can be used either in <b>sentence mode</b>, which allows you to enter a sentence for which an emotion will be predicted; or in <b>dataset mode</b>, which allows you to upload a dataset or see the full functionality with example data.</div>
        """)
        # Sentence mode: classify one sentence typed by the user.
        with gr.Tab("Sentence"):
            gr.Markdown("""
            """)
            with gr.Row():
                with gr.Column():
                    # Named sentence_input/sentence_output so the builtin
                    # input() is not shadowed at module level.
                    sentence_input = gr.Textbox(
                        label="Enter a sentence",
                        value="Jaaah! Volgende vakantie Barcelona en na het zomerseizoen naar de Algarve",
                        lines=1)
                    send_btn = gr.Button("Send")
                    sentence_output = gr.Textbox()
            send_btn.click(fn=inference_sentence, inputs=sentence_input, outputs=sentence_output)
        # Showcase mode: display precomputed results for an example dataset.
        with gr.Tab("Showcase"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("""
                    **<font size="4">Run the demo on the data of a specific crisis case</font>**
                    Select the desired dataset and click the button to run the demo.
                    """)
                with gr.Column():
                    gr.Markdown("""
                    """)
                with gr.Column():
                    gr.Markdown("""
                    **<font size="4">Output</font>**
                    After having clicked on the run button, scroll down to see the output (running may take a while):
                    """)
            with gr.Row():
                with gr.Column():
                    # demo1_btn = gr.Button("The Voice of Holland", variant="primary")
                    # demo2_btn = gr.Button("Floodings", variant="primary")
                    # demo3_btn = gr.Button("COVID-19", variant="primary")
                    # demo4_btn = gr.Button("Childcare Benefits", variant="primary")
                    dataset = gr.Dropdown(["The Voice of Holland", "Floodings", "COVID-19", "Childcare Benefits"], show_label=False)
                    run_btn = gr.Button("Run", variant="primary")

                with gr.Column():
                    gr.Markdown("""
                    **The Voice of Holland:** 18,502 tweets about a scandal about sexual misconduct in the Dutch reality TV singing competition 'The Voice of Holland'.

                    **Floodings:** 9,923 tweets about the floodings that affected Belgium and the Netherlands in the Summer of 2021.

                    **COVID-19:** 609,206 tweets about the COVID-19 pandemic, posted in the first eight months of the crisis.

                    **Childcare Benefits:** 66,961 tweets about the political scandal concerning false allegations of fraud regarding childcare allowance in the Netherlands.
                    """)
                with gr.Column():
                    gr.Markdown("""
                    **Predictions:** file with the predicted emotion label for each instance in the dataset.
                    **Emotion frequencies:** bar plot with the prediction frequencies of each emotion category (anger, fear, joy, love, sadness or neutral).
                    **Emotion distribution over time:** line plot that visualises the frequency of predicted emotions over time for each emotion category.
                    **Peaks:** step graph that only shows the significant fluctuations (upwards and downwards) in emotion frequencies over time.
                    **Topics:** a bar plot that shows the emotion distribution for different topics in the dataset. Topics are extracted using [BERTopic](https://maartengr.github.io/BERTopic/index.html).
                    """)

            with gr.Row():
                gr.Markdown("""
                ___
                """)
            with gr.Row():
                with gr.Column():
                    # All output components start hidden; showcase() returns
                    # updates that make them visible.
                    output_markdown = gr.Markdown("""
                    **<font size="4">Output</font>**
                    """, visible=False)

                    # Not wired to any event in this file.
                    message = gr.Textbox(label="Message", visible=False)
                    output_file = gr.File(label="Predictions", visible=False)
                    output_plot = gr.Plot(show_label=False, visible=False).style(container=True)
                    output_dist = gr.Plot(show_label=False, visible=False)
                    output_peaks = gr.Plot(show_label=False, visible=False)
                    output_topics = gr.Plot(show_label=False, visible=False)

            run_btn.click(fn=showcase, inputs=[dataset], outputs=[output_markdown, output_file, output_plot, output_dist, output_peaks, output_topics])


        with gr.Row():
            with gr.Column():
                gr.Markdown("""
                <font size="2">Both this demo and the dataset have been created by [LT3](https://lt3.ugent.be/), the Language and Translation Technology Team of Ghent University. The EmotioNL project has been carried out with support from the Research Foundation – Flanders (FWO). For any questions, please contact [email protected].</font>

                <div style="display: grid;grid-template-columns:150px auto;"> <img style="margin-right: 1em" alt="LT3 logo" src="https://lt3.ugent.be/static/images/logo_v2_single.png" width="136" height="58"> <img style="margin-right: 1em" alt="FWO logo" src="https://www.fwo.be/images/logo_desktop.png" height="58"></div>
                """)
    with gr.Column(scale=1, min_width=50):
        gr.Markdown("""
        """)

demo.launch()
|