Spaces: Runtime error
Olivia Figueira committed · Commit e3c1abf · 1 Parent(s): 3c050d3

Refactored LM inits and changed app ui

critic/critic.py  (+110 -112)
CHANGED
@@ -20,7 +20,7 @@ import streamlit as st
 
 st.subheader('Exploring Unsupervised Grammatical Error Correction with Transformer-Based Models')
 st.write('This live demonstration is adapted from the paper [LM-Critic: Language Models for Unsupervised Grammatical Error Correction](https://aclanthology.org/2021.emnlp-main.611.pdf) (EMNLP 2021) by Michihiro Yasunaga, Jure Leskovec, Percy Liang.')
-st.write('
+st.write('Enter any sentence in the text box, press submit, and see the grammatical scoring and judgement results output by LM-Critic using different LMs displayed below. Upon running this for the first time, it will initialize each LM.')
 
 def get_gpt2_loss(model, tokenizer, input_ids, attention_mask, labels):
     with torch.no_grad():
@@ -142,132 +142,120 @@ def gpt2_critic(sent, model, tokenizer, verbose=1, cuda=False, fp16=True, seed='
     counter_example = [sents[best_idx], float(logps[best_idx])]
     return is_good, float(logps[0]), counter_example, return_string
 
-def
+def gpt2():
+    ## GPT-2 LM (original LM-critic)
     placeholder_lm_name = st.empty()
-
-
-
-
-
-
-
-
-
-
-
-
-        model_gpt2.cpu()
-        st.session_state["model_gpt2"] = model_gpt2
-        st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
-        st.session_state["nice_name_gpt2"] = nice_name_gpt2
-
-        prog += 10
-        my_bar.progress(prog)
-
-    if "nice_name_opt" not in st.session_state:
-        ## OPT LM
-        model_name_opt = "facebook/opt-350m"
-        nice_name_opt = "OPT"
-        placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
-        model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
-        tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
-        tokenizer_opt.pad_token = tokenizer_opt.eos_token
-        model_opt.eval()
-        model_opt.cpu()
-        st.session_state["model_opt"] = model_opt
-        st.session_state["tokenizer_opt"] = tokenizer_opt
-        st.session_state["nice_name_opt"] = nice_name_opt
-
-        prog += 10
-        my_bar.progress(prog)
-
-    if "nice_name_gptneo" not in st.session_state:
-        ## GPT NEO
-        model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
-        nice_name_gptneo = "GPT NEO"
-        placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
-        model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
-        tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
-        tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
-        model_gptneo.eval()
-        model_gptneo.cpu()
-        st.session_state["model_gptneo"] = model_gptneo
-        st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
-        st.session_state["nice_name_gptneo"] = nice_name_gptneo
-
-        prog += 10
-        my_bar.progress(prog)
-
-    if "nice_name_roberta" not in st.session_state:
-        ## RoBERTa
-        model_name_roberta = "roberta-base"
-        nice_name_roberta = "RoBERTa"
-        placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
-        tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
-        config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
-        config_roberta.is_decoder = True
-        model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
-        tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
-        model_roberta.eval()
-        model_roberta.cpu()
-        st.session_state["model_roberta"] = model_gptneo
-        st.session_state["tokenizer_roberta"] = tokenizer_roberta
-        st.session_state["nice_name_roberta"] = nice_name_roberta
+    model_name_gpt2 = 'gpt2'
+    nice_name_gpt2 = "GPT-2"
+    placeholder_lm_name.text(f"Initializing {nice_name_gpt2}...")
+    tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(model_name_gpt2)
+    tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token
+    model_gpt2 = GPT2LMHeadModel.from_pretrained(model_name_gpt2)
+    model_gpt2.eval()
+    model_gpt2.cpu()
+    placeholder_lm_name.empty()
+    st.session_state["model_gpt2"] = model_gpt2
+    st.session_state["tokenizer_gpt2"] = tokenizer_gpt2
+    st.session_state["nice_name_gpt2"] = nice_name_gpt2
 
-
-
+def opt():
+    ## OPT LM
+    placeholder_lm_name = st.empty()
+    model_name_opt = "facebook/opt-350m"
+    nice_name_opt = "OPT"
+    placeholder_lm_name.text(f"Initializing {nice_name_opt}...")
+    model_opt = OPTForCausalLM.from_pretrained(model_name_opt)
+    tokenizer_opt = GPT2Tokenizer.from_pretrained(model_name_opt)
+    tokenizer_opt.pad_token = tokenizer_opt.eos_token
+    model_opt.eval()
+    model_opt.cpu()
+    placeholder_lm_name.empty()
+    st.session_state["model_opt"] = model_opt
+    st.session_state["tokenizer_opt"] = tokenizer_opt
+    st.session_state["nice_name_opt"] = nice_name_opt
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def gpt_neo():
+    ## GPT NEO
+    placeholder_lm_name = st.empty()
+    model_name_gptneo = "EleutherAI/gpt-neo-1.3B"
+    nice_name_gptneo = "GPT NEO"
+    placeholder_lm_name.text(f"Initializing {nice_name_gptneo}...")
+    model_gptneo = GPTNeoForCausalLM.from_pretrained(model_name_gptneo)
+    tokenizer_gptneo = GPT2Tokenizer.from_pretrained(model_name_gptneo)
+    tokenizer_gptneo.pad_token = tokenizer_gptneo.eos_token
+    model_gptneo.eval()
+    model_gptneo.cpu()
+    placeholder_lm_name.empty()
+    st.session_state["model_gptneo"] = model_gptneo
+    st.session_state["tokenizer_gptneo"] = tokenizer_gptneo
+    st.session_state["nice_name_gptneo"] = nice_name_gptneo
 
-
-
+def roberta():
+    ## RoBERTa
+    placeholder_lm_name = st.empty()
+    model_name_roberta = "roberta-base"
+    nice_name_roberta = "RoBERTa"
+    placeholder_lm_name.text(f"Initializing {nice_name_roberta}...")
+    tokenizer_roberta = RobertaTokenizer.from_pretrained(model_name_roberta)
+    config_roberta = RobertaConfig.from_pretrained(model_name_roberta)
+    config_roberta.is_decoder = True
+    model_roberta = RobertaForCausalLM.from_pretrained(model_name_roberta, config=config_roberta)
+    tokenizer_roberta.pad_token = tokenizer_roberta.eos_token
+    model_roberta.eval()
+    model_roberta.cpu()
+    placeholder_lm_name.empty()
+    st.session_state["model_roberta"] = model_roberta
+    st.session_state["tokenizer_roberta"] = tokenizer_roberta
+    st.session_state["nice_name_roberta"] = nice_name_roberta
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def bart():
+    ## BART
+    placeholder_lm_name = st.empty()
+    model_name_bart = "facebook/bart-base"
+    nice_name_bart = "BART"
+    placeholder_lm_name.text(f"Initializing {nice_name_bart}...")
+    tokenizer_bart = BartTokenizer.from_pretrained(model_name_bart)
+    model_bart = BartForCausalLM.from_pretrained(model_name_bart, add_cross_attention=False)
+    assert model_bart.config.is_decoder, f"{model_bart.__class__} has to be configured as a decoder."
+    tokenizer_bart.pad_token = tokenizer_bart.eos_token
+    model_bart.eval()
+    model_bart.cpu()
+    placeholder_lm_name.empty()
+    st.session_state["model_bart"] = model_bart
+    st.session_state["tokenizer_bart"] = tokenizer_bart
+    st.session_state["nice_name_bart"] = nice_name_bart
 
-
-
+def xlm_roberta():
+    ## XLM RoBERTa
+    placeholder_lm_name = st.empty()
+    model_name_xlmroberta = 'xlm-roberta-base'
+    nice_name_xlmroberta = 'XLM RoBERTa'
+    placeholder_lm_name.text(f"Initializing {nice_name_xlmroberta}...")
+    tokenizer_xlmroberta = XLMRobertaTokenizer.from_pretrained(model_name_xlmroberta)
+    config_xlmroberta = XLMRobertaConfig.from_pretrained(model_name_xlmroberta)
+    config_xlmroberta.is_decoder = True
+    model_xlmroberta = XLMRobertaForCausalLM.from_pretrained(model_name_xlmroberta, config=config_xlmroberta)
+    tokenizer_xlmroberta.pad_token = tokenizer_xlmroberta.eos_token
+    model_xlmroberta.eval()
+    model_xlmroberta.cpu()
     placeholder_lm_name.empty()
-
+    st.session_state["model_xlmroberta"] = model_xlmroberta
+    st.session_state["tokenizer_xlmroberta"] = tokenizer_xlmroberta
+    st.session_state["nice_name_xlmroberta"] = nice_name_xlmroberta
 
 def main():
-
-
-
+    form = st.form(key='my_form')
+    sent = form.text_input(label='Enter a sentence:', value="")
+    submit = form.form_submit_button(label='Submit')
 
-
-    if sent != '':
+    if submit and sent != '':
         st.markdown(f"**Input Sentence**: {sent}")
         results = {}
 
         with st.spinner('Running with GPT-2 LM...'):
             ## GPT-2 LM (original LM-critic)
+            if "nice_name_gpt2" not in st.session_state:
+                gpt2()
             is_good, score, counter_example, return_string_GPT2 = gpt2_critic(sent, st.session_state['model_gpt2'], st.session_state['tokenizer_gpt2'])
             st.markdown("**Results with GPT-2 LM:**")
             st.write('\n'.join(return_string_GPT2))
@@ -275,6 +263,8 @@ def main():
 
         with st.spinner('Running with OPT LM...'):
             ## OPT LM
+            if "nice_name_opt" not in st.session_state:
+                opt()
             is_good, score, counter_example, return_string_OPT = gpt2_critic(sent, st.session_state['model_opt'], st.session_state['tokenizer_opt'])
             st.markdown("**Results with OPT LM:**")
             st.write('\n'.join(return_string_OPT))
@@ -282,6 +272,8 @@ def main():
 
         with st.spinner('Running with GPT NEO LM...'):
            ## GPT NEO
+            if "nice_name_gptneo" not in st.session_state:
+                gpt_neo()
             is_good, score, counter_example, return_string_GPTNEO = gpt2_critic(sent, st.session_state['model_gptneo'], st.session_state['tokenizer_gptneo'])
             st.markdown("**Results with GPT NEO LM:**")
             st.write('\n'.join(return_string_GPTNEO))
@@ -289,6 +281,8 @@ def main():
 
         with st.spinner('Running with RoBERTa LM...'):
             ## RoBERTa
+            if "nice_name_roberta" not in st.session_state:
+                roberta()
             is_good, score, counter_example, return_string_RoBERTa = gpt2_critic(sent, st.session_state['model_roberta'], st.session_state['tokenizer_roberta'])
             st.markdown("**Results with RoBERTa LM:**")
             st.write('\n'.join(return_string_RoBERTa))
@@ -296,6 +290,8 @@ def main():
 
        with st.spinner('Running with BART LM...'):
            ## BART
+            if "nice_name_bart" not in st.session_state:
+                bart()
            is_good, score, counter_example, return_string_BART = gpt2_critic(sent, st.session_state['model_bart'], st.session_state['tokenizer_bart'])
            st.markdown("**Results with BART LM:**")
            st.write('\n'.join(return_string_BART))
@@ -303,6 +299,8 @@ def main():
 
        with st.spinner('Running with XLM RoBERTa LM...'):
            ## XLM RoBERTa
+            if "nice_name_xlmroberta" not in st.session_state:
+                xlm_roberta()
            is_good, score, counter_example, return_string_XLMRoBERTa = gpt2_critic(sent, st.session_state['model_xlmroberta'], st.session_state['tokenizer_xlmroberta'])
            st.markdown("**Results with XLM RoBERTa LM:**")
            st.write('\n'.join(return_string_XLMRoBERTa))
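
Note: the six per-model functions this commit adds (gpt2(), opt(), gpt_neo(), roberta(), bart(), xlm_roberta()) share the same lazy-initialization shape and differ mainly in checkpoint name and tokenizer/model classes. For the three decoder-only checkpoints, the pattern could plausibly be collapsed into one parameterized helper. The sketch below is hypothetical and not part of this commit: init_lm and its key argument are invented names, and it substitutes the generic AutoTokenizer/AutoModelForCausalLM classes for the model-specific ones the file imports (the RoBERTa-style models would still need their is_decoder config tweak).

# Hypothetical consolidation of the per-model init functions above;
# a sketch, not the commit's actual code.
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

def init_lm(key, model_name, nice_name):
    # Skip loading entirely if an earlier run already cached this model.
    if f"nice_name_{key}" in st.session_state:
        return
    placeholder = st.empty()
    placeholder.text(f"Initializing {nice_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-style tokenizers ship without a pad token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()  # inference only: disable dropout
    model.cpu()   # the Space runs on CPU
    placeholder.empty()
    st.session_state[f"model_{key}"] = model
    st.session_state[f"tokenizer_{key}"] = tokenizer
    st.session_state[f"nice_name_{key}"] = nice_name

# Usage mirroring the guards in main():
# init_lm("gpt2", "gpt2", "GPT-2")
# init_lm("opt", "facebook/opt-350m", "OPT")
# init_lm("gptneo", "EleutherAI/gpt-neo-1.3B", "GPT NEO")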