antalvdb commited on
Commit
e72d5c9
·
1 Parent(s): 38f7e93

Upload 3 files

Browse files
opentaal-annotaties.txt.corrections ADDED
The diff for this file is too large to render. See raw diff
 
opentaal-annotaties.txt.errors ADDED
The diff for this file is too large to render. See raw diff
 
spell.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, BartForConditionalGeneration, AutoTokenizer
2
+ from evaluate import load
3
+ import re
4
+
5
+ model = BartForConditionalGeneration.from_pretrained('/home/antalb/software/spelling/bart-base-spelling-nl-9m-3')
6
+ tokenizer = AutoTokenizer.from_pretrained('/home/antalb/software/spelling/bart-base-spelling-nl-9m-3')
7
+
8
+ fix_spelling = pipeline("text2text-generation",model=model,tokenizer=tokenizer)
9
+ cer = load("cer")
10
+ wer = load("wer")
11
+ bleu = load("bleu")
12
+ meteor = load("meteor")
13
+
14
+ file1name = 'opentaal-annotaties.txt.errors'
15
+ file2name = 'opentaal-annotaties.txt.corrections'
16
+
17
+ predictions=[]
18
+ references=[]
19
+
20
+ counter=0;
21
+
22
+ #clean_chars = re.compile(r'[^A-Za-zöäüÖÄÜß,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)
23
+ clean_chars = re.compile(r'[^A-Za-zëïöäüÖÄÜ,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)
24
+ def cleanup(text):
25
+ text = clean_chars.sub('', text)
26
+ #print("bug: somehow all numbers are removed - this is might be due to thisregex")
27
+ #exit()
28
+ #text = text.replace("\n", "")
29
+ #text = text.replace('"','\\"')
30
+ return text
31
+
32
+ with open(file1name, "r") as file1, open(file2name, "r") as file2:
33
+ for line1, line2 in zip(file1, file2):
34
+
35
+ line1 = cleanup(line1)
36
+
37
+ # for actual spelling correction evaluation:
38
+
39
+ intermediate=(fix_spelling(line1,max_length=2048))
40
+ line=intermediate[0]['generated_text'];
41
+
42
+ # for lower-bound testing on the errors:
43
+ #line = line1
44
+
45
+ print(line1)
46
+ print(line)
47
+
48
+ line2 = cleanup(line2)
49
+ print(line2)
50
+
51
+ if len(line)>0 and len(line2)>0:
52
+ predictions.append(line)
53
+ references.append(line2)
54
+
55
+ if counter%100==0:
56
+ print(counter)
57
+ cer_score = cer.compute(predictions=predictions, references=references)
58
+ print('CER - ' + str(cer_score))
59
+ wer_score = wer.compute(predictions=predictions, references=references)
60
+ print('WER - ' + str(wer_score))
61
+ bleu_score = bleu.compute(predictions=predictions, references=references)
62
+ print('BLEU - ' + str(bleu_score))
63
+ meteor_score = meteor.compute(predictions=predictions, references=references)
64
+ print('METEOR - ' + str(meteor_score))
65
+
66
+ counter+=1
67
+
68
+ cer_score = cer.compute(predictions=predictions, references=references)
69
+ print('CER - ' + str(cer_score))
70
+ wer_score = wer.compute(predictions=predictions, references=references)
71
+ print('WER - ' + str(wer_score))
72
+ bleu_score = bleu.compute(predictions=predictions, references=references)
73
+ print('BLEU - ' + str(bleu_score))
74
+ meteor_score = meteor.compute(predictions=predictions, references=references)
75
+ print('METEOR - ' + str(meteor_score))