elsayedissa
committed on
Commit
·
f3f2922
1
Parent(s):
ef96ccf
Update README.md
Browse files
README.md
CHANGED
@@ -113,7 +113,7 @@ processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-arabic-5
|
|
113 |
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-arabic-5k-steps")
|
114 |
|
115 |
# dataset
|
116 |
-
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test", )#cache_dir=args.cache_dir
|
117 |
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
118 |
|
119 |
#for debugging: it gets two examples
|
@@ -136,11 +136,11 @@ def normalize(batch):
|
|
136 |
return batch
|
137 |
|
138 |
def map_wer(batch):
|
139 |
-
model.to(
|
140 |
-
forced_decoder_ids = processor.get_decoder_prompt_ids(language =
|
141 |
inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
|
142 |
with torch.no_grad():
|
143 |
-
generated_ids = model.generate(inputs=inputs.to(
|
144 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
145 |
batch["predicted_text"] = clean_text(transcription)
|
146 |
return batch
|
@@ -148,10 +148,10 @@ def map_wer(batch):
|
|
148 |
# process GOLD text
|
149 |
processed_dataset = dataset.map(normalize)
|
150 |
# get predictions
|
151 |
-
|
152 |
|
153 |
# word error rate
|
154 |
-
wer = wer_metric.compute(references=
|
155 |
wer = round(100 * wer, 2)
|
156 |
print("WER:", wer)
|
157 |
```
|
|
|
113 |
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-arabic-5k-steps")
|
114 |
|
115 |
# dataset
|
116 |
+
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "ar", split="test", ) #cache_dir=args.cache_dir
|
117 |
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
118 |
|
119 |
#for debugging: it gets two examples
|
|
|
136 |
return batch
|
137 |
|
138 |
def map_wer(batch):
|
139 |
+
model.to(device)
|
140 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(language = "ar", task = "transcribe")
|
141 |
inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
|
142 |
with torch.no_grad():
|
143 |
+
generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
|
144 |
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
145 |
batch["predicted_text"] = clean_text(transcription)
|
146 |
return batch
|
|
|
148 |
# process GOLD text
|
149 |
processed_dataset = dataset.map(normalize)
|
150 |
# get predictions
|
151 |
+
predicted = processed_dataset.map(map_wer)
|
152 |
|
153 |
# word error rate
|
154 |
+
wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
|
155 |
wer = round(100 * wer, 2)
|
156 |
print("WER:", wer)
|
157 |
```
|