bakrianoo commited on
Commit
bf21a79
·
1 Parent(s): ff21d20

Upload Initial Model

Browse files
README.md ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: ar
3
+ datasets:
4
+ - common_voice
5
+ metrics:
6
+ - wer
7
+ tags:
8
+ - audio
9
+ - automatic-speech-recognition
10
+ - speech
11
+ - xlsr-fine-tuning-week
12
+ license: apache-2.0
13
+ model-index:
14
+ - name: Sinai Voice Arabic Specch Recognition Model
15
+ results:
16
+ - task:
17
+ name: Speech Recognition
18
+ type: automatic-speech-recognition
19
+ dataset:
20
+ name: Common Voice ar
21
+ type: common_voice
22
+ args: ar
23
+ metrics:
24
+ - name: Test WER
25
+ type: wer
26
+ value: 40.2
27
+ ---
28
+
29
+ # Sinai Voice Arabic Speech Recognition Model
30
+ # نموذج **صوت سيناء** للتعرف على الأصوات العربية الفصحى و تحويلها إلى نصوص
31
+ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53)
32
+ on Arabic using the [Common Voice](https://huggingface.co/datasets/common_voice)
33
+
34
+
35
+ ## Usage
36
+
37
+ Please install:
38
+ - [PyTorch](https://pytorch.org/)
39
+ - `$ pip3 install jiwer lang_trans torchaudio datasets transformers`
40
+
41
+ The model can be used directly (without a language model) as follows:
42
+ ```python
43
+ import torch
44
+ import torchaudio
45
+ from datasets import load_dataset
46
+ from lang_trans.arabic import buckwalter
47
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
48
+ dataset = load_dataset("common_voice", "ar", split="test[:10]")
49
+ resamplers = { # all three sampling rates exist in test split
50
+ 48000: torchaudio.transforms.Resample(48000, 16000),
51
+ 44100: torchaudio.transforms.Resample(44100, 16000),
52
+ 32000: torchaudio.transforms.Resample(32000, 16000),
53
+ }
54
+ def prepare_example(example):
55
+ speech, sampling_rate = torchaudio.load(example["path"])
56
+ example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
57
+ return example
58
+ dataset = dataset.map(prepare_example)
59
+ processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
60
+ model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").eval()
61
+ def predict(batch):
62
+ inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
63
+ with torch.no_grad():
64
+ predicted = torch.argmax(model(inputs.input_values).logits, dim=-1)
65
+ predicted[predicted == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
66
+ batch["predicted"] = processor.tokenizer.batch_decode(predicted)
67
+ return batch
68
+ dataset = dataset.map(predict, batched=True, batch_size=1, remove_columns=["speech"])
69
+ for reference, predicted in zip(dataset["sentence"], dataset["predicted"]):
70
+ print("reference:", reference)
71
+ print("predicted:", predicted)
72
+ print("--")
73
+ ```
74
+ Here's the output:
75
+ ```
76
+ reference: ألديك قلم ؟
77
+ predicted: ألديك قلم
78
+ --
79
+ reference: ليست هناك مسافة على هذه الأرض أبعد من يوم أمس.
80
+ predicted: ليست نارك مسافة على هذه الأرض أبعد من يوم أمس
81
+ --
82
+ reference: إنك تكبر المشكلة.
83
+ predicted: إنك تكبر المشكلة
84
+ --
85
+ reference: يرغب أن يلتقي بك.
86
+ predicted: يرغب أن يلتقي بك
87
+ --
88
+ reference: إنهم لا يعرفون لماذا حتى.
89
+ predicted: إنهم لا يعرفون لماذا حتى
90
+ --
91
+ reference: سيسعدني مساعدتك أي وقت تحب.
92
+ predicted: سيسعدن مساعثتك أي وقد تحب
93
+ --
94
+ reference: أَحَبُّ نظريّة علمية إليّ هي أن حلقات زحل مكونة بالكامل من الأمتعة المفقودة.
95
+ predicted: أحب نظرية علمية إلي هي أن أحلقتز حلم كوينا بالكامل من الأمت عن المفقودة
96
+ --
97
+ reference: سأشتري له قلماً.
98
+ predicted: سأشتري له قلما
99
+ --
100
+ reference: أين المشكلة ؟
101
+ predicted: أين المشكل
102
+ --
103
+ reference: وَلِلَّهِ يَسْجُدُ مَا فِي السَّمَاوَاتِ وَمَا فِي الْأَرْضِ مِنْ دَابَّةٍ وَالْمَلَائِكَةُ وَهُمْ لَا يَسْتَكْبِرُونَ
104
+ predicted: ولله يسجد ما في السماوات وما في الأرض من دابة والملائكة وهم لا يستكبرون
105
+ ```
106
+ ## Evaluation
107
+
108
+ CLONED from [elgeish/wav2vec2-large-xlsr-53-arabic](https://huggingface.co/elgeish/wav2vec2-large-xlsr-53-arabic)
109
+
110
+ The model can be evaluated as follows on the Arabic test data of Common Voice:
111
+ ```python
112
+ import jiwer
113
+ import torch
114
+ import torchaudio
115
+ from datasets import load_dataset
116
+ from lang_trans.arabic import buckwalter
117
+ from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor
118
+ set_seed(42)
119
+ test_split = load_dataset("common_voice", "ar", split="test")
120
+ resamplers = { # all three sampling rates exist in test split
121
+ 48000: torchaudio.transforms.Resample(48000, 16000),
122
+ 44100: torchaudio.transforms.Resample(44100, 16000),
123
+ 32000: torchaudio.transforms.Resample(32000, 16000),
124
+ }
125
+ def prepare_example(example):
126
+ speech, sampling_rate = torchaudio.load(example["path"])
127
+ example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
128
+ return example
129
+ test_split = test_split.map(prepare_example)
130
+ processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
131
+ model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").to("cuda").eval()
132
+ def predict(batch):
133
+ inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
134
+ with torch.no_grad():
135
+ predicted = torch.argmax(model(inputs.input_values.to("cuda")).logits, dim=-1)
136
+ predicted[predicted == -100] = processor.tokenizer.pad_token_id # see fine-tuning script
137
+ batch["predicted"] = processor.batch_decode(predicted)
138
+ return batch
139
+ test_split = test_split.map(predict, batched=True, batch_size=16, remove_columns=["speech"])
140
+ transformation = jiwer.Compose([
141
+ # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
142
+ jiwer.SubstituteRegexes({
143
+ r'[auiFNKo\~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
144
+ r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
145
+ # default transformation below
146
+ jiwer.RemoveMultipleSpaces(),
147
+ jiwer.Strip(),
148
+ jiwer.SentencesToListOfWords(),
149
+ jiwer.RemoveEmptyStrings(),
150
+ ])
151
+ metrics = jiwer.compute_measures(
152
+ truth=[buckwalter.trans(s) for s in test_split["sentence"]], # Buckwalter transliteration
153
+ hypothesis=test_split["predicted"],
154
+ truth_transform=transformation,
155
+ hypothesis_transform=transformation,
156
+ )
157
+ print(f"WER: {metrics['wer']:.2%}")
158
+ ```
159
+ **Test Result**: 40.2%
config.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
3
+ "activation_dropout": 0.0,
4
+ "apply_spec_augment": true,
5
+ "architectures": [
6
+ "Wav2Vec2ForCTC"
7
+ ],
8
+ "attention_dropout": 0.15,
9
+ "bos_token_id": 1,
10
+ "conv_bias": true,
11
+ "conv_dim": [
12
+ 512,
13
+ 512,
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512,
18
+ 512
19
+ ],
20
+ "conv_kernel": [
21
+ 10,
22
+ 3,
23
+ 3,
24
+ 3,
25
+ 3,
26
+ 2,
27
+ 2
28
+ ],
29
+ "conv_stride": [
30
+ 5,
31
+ 2,
32
+ 2,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2
37
+ ],
38
+ "ctc_loss_reduction": "mean",
39
+ "ctc_zero_infinity": false,
40
+ "do_stable_layer_norm": true,
41
+ "eos_token_id": 2,
42
+ "feat_extract_activation": "gelu",
43
+ "feat_extract_dropout": 0.0,
44
+ "feat_extract_norm": "layer",
45
+ "feat_proj_dropout": 0.0,
46
+ "final_dropout": 0.0,
47
+ "gradient_checkpointing": true,
48
+ "hidden_act": "gelu",
49
+ "hidden_dropout": 0.15,
50
+ "hidden_size": 1024,
51
+ "initializer_range": 0.02,
52
+ "intermediate_size": 4096,
53
+ "layer_norm_eps": 1e-05,
54
+ "layerdrop": 0.1,
55
+ "mask_channel_length": 10,
56
+ "mask_channel_min_space": 1,
57
+ "mask_channel_other": 0.0,
58
+ "mask_channel_prob": 0.0,
59
+ "mask_channel_selection": "static",
60
+ "mask_feature_length": 10,
61
+ "mask_feature_prob": 0.0,
62
+ "mask_time_length": 10,
63
+ "mask_time_min_space": 1,
64
+ "mask_time_other": 0.0,
65
+ "mask_time_prob": 0.05,
66
+ "mask_time_selection": "static",
67
+ "model_type": "wav2vec2",
68
+ "num_attention_heads": 16,
69
+ "num_conv_pos_embedding_groups": 16,
70
+ "num_conv_pos_embeddings": 128,
71
+ "num_feat_extract_layers": 7,
72
+ "num_hidden_layers": 24,
73
+ "pad_token_id": 45,
74
+ "transformers_version": "4.4.0",
75
+ "vocab_size": 48
76
+ }
optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff523d662d9db49ccaadba49574509f044262f42534b15281436d469f3d2a65e
3
+ size 2490464146
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_size": 1,
4
+ "padding_side": "right",
5
+ "padding_value": 0.0,
6
+ "return_attention_mask": true,
7
+ "sampling_rate": 16000
8
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe340e53abd873e326771a32ab2b1d21f4ae07a98401dd0ada85940dbacc6e92
3
+ size 1262126414
scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7db558e13c832cb0475ee05ac56dc0ed99b68cb98e8687b128d7b06e57d53360
3
+ size 623
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|"}
trainer_state.json ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 21.238938053097346,
5
+ "global_step": 2400,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 3.54,
12
+ "learning_rate": 0.0001926605504587156,
13
+ "loss": NaN,
14
+ "step": 400
15
+ },
16
+ {
17
+ "epoch": 3.54,
18
+ "eval_loss": 2.914095878601074,
19
+ "eval_runtime": 779.3384,
20
+ "eval_samples_per_second": 9.78,
21
+ "eval_wer": 1.0,
22
+ "step": 400
23
+ },
24
+ {
25
+ "epoch": 7.08,
26
+ "learning_rate": 0.0001779816513761468,
27
+ "loss": NaN,
28
+ "step": 800
29
+ },
30
+ {
31
+ "epoch": 7.08,
32
+ "eval_loss": 0.5257614850997925,
33
+ "eval_runtime": 808.0741,
34
+ "eval_samples_per_second": 9.432,
35
+ "eval_wer": 0.4909435120753172,
36
+ "step": 800
37
+ },
38
+ {
39
+ "epoch": 10.62,
40
+ "learning_rate": 0.00016330275229357798,
41
+ "loss": NaN,
42
+ "step": 1200
43
+ },
44
+ {
45
+ "epoch": 10.62,
46
+ "eval_loss": 0.4604354500770569,
47
+ "eval_runtime": 816.9678,
48
+ "eval_samples_per_second": 9.33,
49
+ "eval_wer": 0.444765656979124,
50
+ "step": 1200
51
+ },
52
+ {
53
+ "epoch": 14.16,
54
+ "learning_rate": 0.00014862385321100919,
55
+ "loss": NaN,
56
+ "step": 1600
57
+ },
58
+ {
59
+ "epoch": 14.16,
60
+ "eval_loss": 0.4556906819343567,
61
+ "eval_runtime": 788.8971,
62
+ "eval_samples_per_second": 9.662,
63
+ "eval_wer": 0.4178008595988539,
64
+ "step": 1600
65
+ },
66
+ {
67
+ "epoch": 17.7,
68
+ "learning_rate": 0.00013394495412844036,
69
+ "loss": NaN,
70
+ "step": 2000
71
+ },
72
+ {
73
+ "epoch": 17.7,
74
+ "eval_loss": 0.44174715876579285,
75
+ "eval_runtime": 788.0029,
76
+ "eval_samples_per_second": 9.673,
77
+ "eval_wer": 0.4084885386819484,
78
+ "step": 2000
79
+ },
80
+ {
81
+ "epoch": 21.24,
82
+ "learning_rate": 0.00011926605504587157,
83
+ "loss": NaN,
84
+ "step": 2400
85
+ },
86
+ {
87
+ "epoch": 21.24,
88
+ "eval_loss": 0.43745675683021545,
89
+ "eval_runtime": 797.2999,
90
+ "eval_samples_per_second": 9.56,
91
+ "eval_wer": 0.4028602128530495,
92
+ "step": 2400
93
+ }
94
+ ],
95
+ "max_steps": 5650,
96
+ "num_train_epochs": 50,
97
+ "total_flos": 8.638551360908018e+19,
98
+ "trial_name": null,
99
+ "trial_params": null
100
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b674eec74b835dba17b22c39ba4f76f198e2fd26941f1588e8a63aa2c526e6f9
3
+ size 2287
vocab.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"ج": 0, "ح": 1, "ﺃ": 2, "ت": 3, "ط": 4, "خ": 5, "چ": 6, "س": 7, "ب": 8, "غ": 10, "ث": 11, "ض": 12, "ا": 13, "ذ": 14, "ھ": 15, "ز": 16, "ى": 17, "ﻻ": 18, "ظ": 19, "ق": 20, "ص": 21, "م": 22, "ف": 23, "د": 24, "ش": 25, "و": 26, "ه": 27, "ی": 28, "ء": 29, "ر": 30, "آ": 31, "ع": 32, "ي": 33, "ل": 34, "ؤ": 35, "ڨ": 36, "ک": 37, "إ": 38, "أ": 39, "ك": 40, "ة": 41, "ئ": 42, "ن": 43, "|": 9, "[UNK]": 44, "[PAD]": 45}