|
import itertools
import re
import sys

from evaluate import load
from transformers import pipeline, BartForConditionalGeneration, AutoTokenizer


# Local checkpoint directory of the BART spelling-correction model. The model
# and its tokenizer are loaded from this single constant so they cannot drift
# apart when the checkpoint is changed.
MODEL_PATH = '/home/antalb/software/spelling/bart-base-spelling-nl-9m-3'

model = BartForConditionalGeneration.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Text-to-text generation pipeline: takes an erroneous line, returns a list of
# dicts whose 'generated_text' holds the corrected line.
fix_spelling = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Hugging Face `evaluate` metrics used to score the model's output against the
# reference corrections.
cer = load("cer")
wer = load("wer")
bleu = load("bleu")
meteor = load("meteor")

# Parallel input files: line i of the errors file is paired with line i of the
# corrections file.
file1name = 'opentaal-annotaties.txt.errors'
file2name = 'opentaal-annotaties.txt.corrections'

# Model outputs and their reference corrections, aligned by index.
predictions = []
references = []

# Index of the line pair currently being processed.
counter = 0


# Characters kept by cleanup(): ASCII letters and digits, the diaeresis letters
# ë ï ö ä ü in both cases, a little punctuation (, . ! ? ’ ' ( ) -), the $ % €
# signs, and whitespace (normalised separately below). Everything else is
# deleted.
# NOTE(review): other accented letters (é, è, ó, ...) are not in the set and
# are deleted outright, e.g. 'één' -> 'n'; confirm this is intended.
clean_chars = re.compile(r"[^A-Za-zëïöäüËÏÖÄÜ,.!?’'$%€0-9()\-\s]")

# Any run of whitespace characters (spaces, tabs, line terminators, ...).
_whitespace_runs = re.compile(r"\s+")


def cleanup(text):
    """Return *text* reduced to the allowed character set with normalised spacing.

    Characters outside ``clean_chars``' allowed set are deleted. Every run of
    whitespace becomes a single space, so a tab between two words still
    separates them. Leading and trailing whitespace, including the line
    terminator, is stripped, so a blank or whitespace-only line comes back
    as ''.
    """
    text = clean_chars.sub('', text)
    return _whitespace_runs.sub(' ', text).strip()


def _report_scores(predictions, references):
    """Print CER, WER, BLEU and METEOR over all pairs collected so far.

    Args:
        predictions: model outputs, one string per scored line.
        references: reference corrections, aligned by index with *predictions*.

    The metric implementations cannot score an empty corpus (they raise), so
    nothing is computed until at least one pair has been collected.
    """
    if not predictions:
        print('No prediction/reference pairs to score yet.')
        return
    for label, metric in (('CER', cer), ('WER', wer), ('BLEU', bleu), ('METEOR', meteor)):
        score = metric.compute(predictions=predictions, references=references)
        print(label + ' - ' + str(score))


# Decode explicitly instead of relying on the locale default; the data holds
# non-ASCII text (ë, ’, € ...). NOTE(review): assumes the corpus is UTF-8 — confirm.
with open(file1name, encoding="utf-8") as file1, open(file2name, encoding="utf-8") as file2:
    # zip_longest rather than zip so a line-count mismatch between the two
    # parallel files is reported instead of being silently truncated.
    for counter, (line1, line2) in enumerate(itertools.zip_longest(file1, file2)):
        if line1 is None or line2 is None:
            print(f'WARNING: {file1name} and {file2name} have different line counts; '
                  f'stopping after {counter} aligned pairs.', file=sys.stderr)
            break

        line1 = cleanup(line1)
        line = fix_spelling(line1, max_length=2048)[0]['generated_text']
        print(line1)
        print(line)

        line2 = cleanup(line2)
        print(line2)

        # Only score pairs where both the prediction and the reference are non-empty.
        if line and line2:
            predictions.append(line)
            references.append(line2)

        # Periodic progress report on the cumulative corpus.
        if counter % 100 == 0:
            print(counter)
            _report_scores(predictions, references)

# Final report over every collected pair.
_report_scores(predictions, references)