Yehor committed
Commit b4ad1cc · 1 Parent(s): 1a38a52
.dockerignore ADDED
@@ -0,0 +1,3 @@
+ .ruff_cache/
+ .venv/
+ models/
.gitignore ADDED
@@ -0,0 +1,9 @@
+ .idea/
+ .venv/
+ .ruff_cache/
+ __pycache__/
+
+ flagged/
+ models/
+
+ audio.wav
RADTTS-LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ Permission is hereby granted, free of charge, to any person obtaining a
+ copy of this software and associated documentation files (the "Software"),
+ to deal in the Software without restriction, including without limitation
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ and/or sell copies of the Software, and to permit persons to whom the
+ Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,29 @@
  ---
- title: Radtts Uk Vocos Demo
- emoji: 📈
- colorFrom: indigo
- colorTo: indigo
+ license: apache-2.0
+ title: RAD-TTS++ Ukrainian (Vocos)
  sdk: gradio
+ emoji: 🎧
+ colorFrom: blue
+ colorTo: gray
+ short_description: Use RAD-TTS++ model to synthesize text in Ukrainian
  sdk_version: 5.19.0
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## Install
+
+ ```shell
+ uv venv --python 3.10
+
+ source .venv/bin/activate
+
+ uv pip install -r requirements.txt
+
+ # in development mode
+ uv pip install -r requirements-dev.txt
+ ```
+
+ ## Run
+
+ ```shell
+ python app.py
+ ```
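Once the app is running it can also be called programmatically. Below is a minimal sketch using `gradio_client` (not part of this repo); the endpoint name is an assumption — run `client.view_api()` to see the names Gradio actually generated for this app:

```python
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
client.view_api()  # prints the auto-generated endpoint names and their signatures

# The endpoint name below is assumed; replace it with the one reported above.
audio_path, rtf_info = client.predict(
    "Сл+ава Укра+їні! — українське вітання, національне гасло.",  # text
    "Tetiana (female) 👩",                                          # voice, one of the radio choices
    api_name="/inference_cpu",
)
print(audio_path, rtf_info)
```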
alignment.py ADDED
@@ -0,0 +1,54 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: MIT
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a
+ # copy of this software and associated documentation files (the "Software"),
+ # to deal in the Software without restriction, including without limitation
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ # and/or sell copies of the Software, and to permit persons to whom the
+ # Software is furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ # DEALINGS IN THE SOFTWARE.
+
+ import numpy as np
+ from numba import jit
+
+
+ @jit(nopython=True)
+ def mas_width1(attn_map):
+     """mas with hardcoded width=1"""
+     # assumes mel x text
+     opt = np.zeros_like(attn_map)
+     attn_map = np.log(attn_map)
+     attn_map[0, 1:] = -np.inf
+     log_p = np.zeros_like(attn_map)
+     log_p[0, :] = attn_map[0, :]
+     prev_ind = np.zeros_like(attn_map, dtype=np.int64)
+     for i in range(1, attn_map.shape[0]):
+         for j in range(attn_map.shape[1]):  # for each text dim
+             prev_log = log_p[i - 1, j]
+             prev_j = j
+
+             if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
+                 prev_log = log_p[i - 1, j - 1]
+                 prev_j = j - 1
+
+             log_p[i, j] = attn_map[i, j] + prev_log
+             prev_ind[i, j] = prev_j
+
+     # now backtrack
+     curr_text_idx = attn_map.shape[1] - 1
+     for i in range(attn_map.shape[0] - 1, -1, -1):
+         opt[i, curr_text_idx] = 1
+         curr_text_idx = prev_ind[i, curr_text_idx]
+     opt[0, curr_text_idx] = 1
+     return opt
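For reference, a small usage sketch of the function added above (illustrative only; the toy attention map and its shape are made up): `mas_width1` takes a positive (n_mel_frames, n_text_tokens) attention map and returns a hard, monotonic 0/1 alignment path.

```python
import numpy as np

from alignment import mas_width1  # the function introduced in this commit

rng = np.random.default_rng(0)
attn = rng.random((8, 3)) + 1e-3         # toy soft attention: (mel frames, text tokens), strictly positive
attn /= attn.sum(axis=1, keepdims=True)  # rows behave like attention weights

path = mas_width1(attn)                  # hard monotonic alignment of the same shape
print(path)                              # each mel frame selects one text token; indices never decrease over time
```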
app.py ADDED
@@ -0,0 +1,356 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import time
5
+
6
+ from importlib.metadata import version
7
+ from enum import Enum
8
+
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ use_zerogpu = False
12
+
13
+ try:
14
+ import spaces # it's for ZeroGPU
15
+ use_zerogpu = True
16
+ print("ZeroGPU is available, changing inference call.")
17
+ except ImportError:
18
+ print("ZeroGPU is not available, skipping...")
19
+
20
+ import gradio as gr
21
+
22
+ import torch
23
+ import torchaudio
24
+
25
+ # Vocos
26
+ from vocos import Vocos
27
+
28
+ # RAD-TTS code
29
+ from radtts import RADTTS
30
+ from data import Data
31
+ from common import update_params
32
+
33
+ use_cuda = torch.cuda.is_available()
34
+
35
+ if use_cuda:
36
+ print("CUDA is available, setting correct inference_device variable.")
37
+ device = "cuda"
38
+ else:
39
+ device = "cpu"
40
+
41
+
42
+ def download_file_from_repo(
43
+ repo_id: str,
44
+ filename: str,
45
+ local_dir: str = ".",
46
+ repo_type: str = "model",
47
+ ) -> str:
48
+ try:
49
+ os.makedirs(local_dir, exist_ok=True)
50
+
51
+ file_path = hf_hub_download(
52
+ repo_id=repo_id,
53
+ filename=filename,
54
+ local_dir=local_dir,
55
+ cache_dir=None,
56
+ force_download=False,
57
+ repo_type=repo_type,
58
+ )
59
+
60
+ return file_path
61
+ except Exception as e:
62
+ raise Exception(f"An error occurred during download: {e}") from e
63
+
64
+
65
+ download_file_from_repo(
66
+ "Yehor/radtts-uk",
67
+ "radtts-pp-dap-model/model_dap_84000.pt",
68
+ "./models/",
69
+ )
70
+
71
+ # Init the model
72
+ seed = 1234
73
+
74
+ config = "configs/radtts-pp-dap-model.json"
75
+ radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
76
+
77
+ params = []
78
+
79
+ # Load the config
80
+ with open(config) as f:
81
+ data = f.read()
82
+
83
+ config = json.loads(data)
84
+ update_params(config, params)
85
+
86
+ data_config = config["data_config"]
87
+ model_config = config["model_config"]
88
+
89
+ # Seed
90
+ torch.manual_seed(seed)
91
+ torch.cuda.manual_seed(seed)
92
+
93
+ # Load vocoder
94
+ vocos = Vocos.from_pretrained("patriotyk/vocos-mel-hifigan-compat-44100khz").to(device)
95
+
96
+ # Load RAD-TTS
97
+ if use_cuda:
98
+ radtts = RADTTS(**model_config).cuda()
99
+ else:
100
+ radtts = RADTTS(**model_config)
101
+
102
+ radtts.enable_inverse_cache() # cache inverse matrix for 1x1 invertible convs
103
+
104
+ checkpoint_dict = torch.load(radtts_path, map_location="cpu") # todo: CPU?
105
+ radtts.load_state_dict(checkpoint_dict["state_dict"], strict=False)
106
+ radtts.eval()
107
+
108
+ print(f"Loaded checkpoint '{radtts_path}')")
109
+
110
+ ignore_keys = ["training_files", "validation_files"]
111
+ trainset = Data(
112
+ data_config["training_files"],
113
+ **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
114
+ )
115
+
116
+ # Config
117
+ concurrency_limit = 5
118
+
119
+ title = "RAD-TTS++ Ukrainian"
120
+
121
+ # https://www.tablesgenerator.com/markdown_tables
122
+ authors_table = """
123
+ ## Authors
124
+
125
+ Follow them on social networks and **contact** if you need any help or have any questions:
126
+
127
+ | <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
128
+ |-------------------------------------------------------------------------------------------------|
129
+ | https://t.me/smlkw in Telegram |
130
+ | https://x.com/yehor_smoliakov at X |
131
+ | https://github.com/egorsmkv at GitHub |
132
+ | https://huggingface.co/Yehor at Hugging Face |
133
+ | or use [email protected] |
134
+ """.strip()
135
+
136
+ description_head = f"""
137
+ # {title}
138
+
139
+ ## Overview
140
+
141
+ Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and [Vocos](https://huggingface.co/patriotyk/vocos-mel-hifigan-compat-44100khz) with 44100 Hz.
142
+ """.strip()
143
+
144
+ description_foot = f"""
145
+ {authors_table}
146
+ """.strip()
147
+
148
+ tech_env = f"""
149
+ #### Environment
150
+
151
+ - Python: {sys.version}
152
+ """.strip()
153
+
154
+ tech_libraries = f"""
155
+ #### Libraries
156
+
157
+ - gradio: {version("gradio")}
158
+ - torch: {version("torch")}
159
+ - scipy: {version("scipy")}
160
+ - numba: {version("numba")}
161
+ - librosa: {version("librosa")}
162
+ - unidecode: {version("unidecode")}
163
+ - inflect: {version("inflect")}
164
+ """.strip()
165
+
166
+
167
+ class VoiceOption(Enum):
168
+ Tetiana = "Tetiana (female) 👩"
169
+ Mykyta = "Mykyta (male) 👨"
170
+ Lada = "Lada (female) 👩"
171
+
172
+
173
+ voice_mapping = {
174
+ VoiceOption.Tetiana.value: "tetiana",
175
+ VoiceOption.Mykyta.value: "mykyta",
176
+ VoiceOption.Lada.value: "lada",
177
+ }
178
+
179
+
180
+ examples = [
181
+ [
182
+ "Прокинувся ґазда вранці. Пішов, вичистив з-під коня, вичистив з-під бика, вичистив з-під овечок, вибрав молодняк, відніс його набік.",
183
+ VoiceOption.Mykyta.value,
184
+ ],
185
+ [
186
+ "Пішов взяв сіна, дав корові. Пішов взяв сіна, дав бикові. Ячміню коняці насипав. Зайшов почистив корову, зайшов почистив бика, зайшов почистив коня, за яйця його мацнув.",
187
+ VoiceOption.Lada.value,
188
+ ],
189
+ [
190
+ "Кінь ногою здригнув, на хазяїна ласкавим оком подивився. Тоді дядько пішов відкрив курей, гусей, качок, повиносив їм зерна, огірків нарізаних, нагодував. Коли чує – з хати дружина кличе. Зайшов. Дітки повмивані, сидять за столом, всі чекають тата. Взяв він ложку, перехрестив дітей, перехрестив лоба, почали снідати. Поснідали, він дістав пряників, роздав дітям. Діти зібралися, пішли в школу. Дядько вийшов, сів на призьбі, взяв сапку, почав мантачити. Мантачив-мантачив, коли – жінка виходить. Він їй ту сапку дає, ласкаво за сраку вщипнув, жінка до нього лагідно всміхнулася, пішла на город – сапати. Коли – йде пастух і товар кличе в череду. Повідмикав дядько овечок, коровку, бика, коня, все відпустив. Сів попри хати, дістав табАку, відірвав шмат газети, насипав, наслинив собі гарну таку цигарку. Благодать божа – і сонечко вже здійнялося над деревами. Дядько встромив цигарку в рота, дістав сірники, тільки чиркати – коли раптом з хати: Доброе утро! Московское время – шесть часов утра! Витяг дядько цигарку с рота, сплюнув набік, і сам собі каже: Ана маєш. Прокинулись, бляді!",
191
+ VoiceOption.Tetiana.value,
192
+ ],
193
+ ]
194
+
195
+
196
+ def inference(text, voice):
197
+ if not text:
198
+ raise gr.Error("Please paste your text.")
199
+
200
+ gr.Info("Starting...", duration=0.5)
201
+
202
+ speaker = voice_mapping[voice]
203
+ speaker = speaker_text = speaker_attributes = speaker
204
+
205
+ n_takes = 1
206
+
207
+ sigma = 0.8 # sampling sigma for decoder
208
+ sigma_tkndur = 0.666 # sampling sigma for duration
209
+ sigma_f0 = 1.0 # sampling sigma for f0
210
+ sigma_energy = 1.0 # sampling sigma for energy avg
211
+
212
+ token_dur_scaling = 1.0
213
+
214
+ f0_mean = 0
215
+ f0_std = 0
216
+ energy_mean = 0
217
+ energy_std = 0
218
+
219
+ if use_cuda:
220
+ speaker_id = trainset.get_speaker_id(speaker).cuda()
221
+ speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
222
+
223
+ if speaker_text is not None:
224
+ speaker_id_text = trainset.get_speaker_id(speaker_text).cuda()
225
+
226
+ if speaker_attributes is not None:
227
+ speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).cuda()
228
+
229
+ tensor_text = trainset.get_text(text).cuda()[None]
230
+ else:
231
+ speaker_id = trainset.get_speaker_id(speaker)
232
+ speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
233
+
234
+ if speaker_text is not None:
235
+ speaker_id_text = trainset.get_speaker_id(speaker_text)
236
+
237
+ if speaker_attributes is not None:
238
+ speaker_id_attributes = trainset.get_speaker_id(speaker_attributes)
239
+
240
+ tensor_text = trainset.get_text(text)[None]
241
+
242
+ inference_start = time.time()
243
+
244
+ for take in range(n_takes):
245
+ with torch.autocast(device, enabled=False):
246
+ with torch.inference_mode():
247
+ outputs = radtts.infer(
248
+ speaker_id,
249
+ tensor_text,
250
+ sigma,
251
+ sigma_tkndur,
252
+ sigma_f0,
253
+ sigma_energy,
254
+ token_dur_scaling,
255
+ token_duration_max=100,
256
+ speaker_id_text=speaker_id_text,
257
+ speaker_id_attributes=speaker_id_attributes,
258
+ f0_mean=f0_mean,
259
+ f0_std=f0_std,
260
+ energy_mean=energy_mean,
261
+ energy_std=energy_std,
262
+ use_cuda=use_cuda,
263
+ )
264
+
265
+ mel = outputs["mel"]
266
+
267
+ gr.Info(
268
+ "Synthesized MEL spectrogram, converting to WAVE.", duration=0.5
269
+ )
270
+
271
+ wav_gen = vocos.decode(mel)
272
+ wav_gen_float = wav_gen.cpu()
273
+
274
+ torchaudio.save("audio.wav", wav_gen_float, 44_100, encoding="PCM_S")
275
+
276
+ duration = len(wav_gen_float[0]) / 44_100
277
+
278
+ elapsed_time = time.time() - inference_start
279
+ rtf = elapsed_time / duration
280
+
281
+ speed_ratio = duration / elapsed_time
282
+ speech_rate = len(text.split(" ")) / duration
283
+
284
+ rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words-per-second."
285
+
286
+ gr.Success("Finished!", duration=0.5)
287
+
288
+ return [gr.Audio("audio.wav"), rtf_value]
289
+
290
+
291
+ try:
292
+ @spaces.GPU
293
+ def inference_zerogpu(text, voice):
294
+ return inference(text, voice)
295
+ except NameError:
296
+ print("ZeroGPU is not available, skipping...")
297
+
298
+
299
+ def inference_cpu(text, voice):
300
+ return inference(text, voice)
301
+
302
+
303
+ demo = gr.Blocks(
304
+ title=title,
305
+ analytics_enabled=False,
306
+ theme=gr.themes.Base(),
307
+ )
308
+
309
+ with demo:
310
+ gr.Markdown(description_head)
311
+
312
+ gr.Markdown("## Usage")
313
+
314
+ with gr.Row():
315
+ with gr.Column():
316
+ audio = gr.Audio(label="Synthesized audio")
317
+ rtf = gr.Markdown(
318
+ label="Real-Time Factor",
319
+ value="Here you will see how fast the model and the speaker are.",
320
+ )
321
+
322
+ with gr.Row():
323
+ with gr.Column():
324
+ text = gr.Text(
325
+ label="Text",
326
+ value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
327
+ )
328
+ voice = gr.Radio(
329
+ label="Voice",
330
+ choices=[option.value for option in VoiceOption],
331
+ value=VoiceOption.Tetiana.value,
332
+ )
333
+
334
+ gr.Button("Run").click(
335
+ inference_zerogpu if use_zerogpu else inference_cpu,
336
+ concurrency_limit=concurrency_limit,
337
+ inputs=[text, voice],
338
+ outputs=[audio, rtf],
339
+ )
340
+
341
+ with gr.Row():
342
+ gr.Examples(
343
+ label="Choose an example",
344
+ inputs=[text, voice],
345
+ examples=examples,
346
+ )
347
+
348
+ gr.Markdown(description_foot)
349
+
350
+ gr.Markdown("### Gradio app uses:")
351
+ gr.Markdown(tech_env)
352
+ gr.Markdown(tech_libraries)
353
+
354
+ if __name__ == "__main__":
355
+ demo.queue()
356
+ demo.launch()
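The metrics string built in `inference()` above follows the usual real-time-factor conventions. A worked example with made-up numbers (the 44.1 kHz sample rate matches the app; the timings are illustrative, not measured):

```python
# Illustrative numbers only — not measured values.
elapsed_time = 2.0                      # seconds spent in RAD-TTS++ + Vocos
n_samples = 8 * 44_100                  # 8 seconds of generated audio at 44.1 kHz
duration = n_samples / 44_100           # 8.0 s

rtf = elapsed_time / duration           # 0.25 — below 1.0 means faster than real time
speed_ratio = duration / elapsed_time   # 4.0x
words = len("Сл+ава Укра+їні! — українське вітання, національне гасло.".split(" "))
speech_rate = words / duration          # words per second of generated audio
print(rtf, speed_ratio, speech_rate)
```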
attribute_prediction_model.py ADDED
@@ -0,0 +1,402 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ import torch
22
+ from torch import nn
23
+ from common import ConvNorm, Invertible1x1Conv
24
+ from common import AffineTransformationLayer, SplineTransformationLayer
25
+ from common import ConvLSTMLinear
26
+ from transformer import FFTransformer
27
+ from autoregressive_flow import AR_Step, AR_Back_Step
28
+
29
+
30
+ def get_attribute_prediction_model(config):
31
+ name = config["name"]
32
+ hparams = config["hparams"]
33
+ if name == "dap":
34
+ model = DAP(**hparams)
35
+ elif name == "bgap":
36
+ model = BGAP(**hparams)
37
+ elif name == "agap":
38
+ model = AGAP(**hparams)
39
+ else:
40
+ raise Exception("{} model is not supported".format(name))
41
+
42
+ return model
43
+
44
+
45
+ class AttributeProcessing:
46
+ def __init__(self, take_log_of_input=False):
47
+ super(AttributeProcessing).__init__()
48
+ self.take_log_of_input = take_log_of_input
49
+
50
+ def normalize(self, x):
51
+ if self.take_log_of_input:
52
+ x = torch.log(x + 1)
53
+ return x
54
+
55
+ def denormalize(self, x):
56
+ if self.take_log_of_input:
57
+ x = torch.exp(x) - 1
58
+ return x
59
+
60
+
61
+ class BottleneckLayerLayer(nn.Module):
62
+ def __init__(
63
+ self,
64
+ in_dim,
65
+ reduction_factor,
66
+ norm="weightnorm",
67
+ non_linearity="relu",
68
+ kernel_size=3,
69
+ use_partial_padding=False,
70
+ ):
71
+ super(BottleneckLayerLayer, self).__init__()
72
+
73
+ self.reduction_factor = reduction_factor
74
+ reduced_dim = int(in_dim / reduction_factor)
75
+ self.out_dim = reduced_dim
76
+ if self.reduction_factor > 1:
77
+ fn = ConvNorm(
78
+ in_dim,
79
+ reduced_dim,
80
+ kernel_size=kernel_size,
81
+ use_weight_norm=(norm == "weightnorm"),
82
+ )
83
+ if norm == "instancenorm":
84
+ fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True))
85
+
86
+ self.projection_fn = fn
87
+ self.non_linearity = nn.ReLU()
88
+ if non_linearity == "leakyrelu":
89
+ self.non_linearity = nn.LeakyReLU()
90
+
91
+ def forward(self, x):
92
+ if self.reduction_factor > 1:
93
+ x = self.projection_fn(x)
94
+ x = self.non_linearity(x)
95
+ return x
96
+
97
+
98
+ class DAP(nn.Module):
99
+ def __init__(
100
+ self,
101
+ n_speaker_dim,
102
+ bottleneck_hparams,
103
+ take_log_of_input,
104
+ arch_hparams,
105
+ use_transformer=False,
106
+ ):
107
+ super(DAP, self).__init__()
108
+ self.attribute_processing = AttributeProcessing(take_log_of_input)
109
+ self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
110
+
111
+ arch_hparams["in_dim"] = self.bottleneck_layer.out_dim + n_speaker_dim
112
+ if use_transformer:
113
+ self.feat_pred_fn = FFTransformer(**arch_hparams)
114
+ else:
115
+ self.feat_pred_fn = ConvLSTMLinear(**arch_hparams)
116
+
117
+ def forward(self, txt_enc, spk_emb, x, lens):
118
+ if x is not None:
119
+ x = self.attribute_processing.normalize(x)
120
+
121
+ txt_enc = self.bottleneck_layer(txt_enc)
122
+ spk_emb_expanded = spk_emb[..., None].expand(-1, -1, txt_enc.shape[2])
123
+ context = torch.cat((txt_enc, spk_emb_expanded), 1)
124
+
125
+ x_hat = self.feat_pred_fn(context, lens)
126
+
127
+ outputs = {"x_hat": x_hat, "x": x}
128
+ return outputs
129
+
130
+ def infer(self, z, txt_enc, spk_emb, lens=None):
131
+ x_hat = self.forward(txt_enc, spk_emb, x=None, lens=lens)["x_hat"]
132
+ x_hat = self.attribute_processing.denormalize(x_hat)
133
+ return x_hat
134
+
135
+
136
+ class BGAP(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ n_in_dim,
140
+ n_speaker_dim,
141
+ bottleneck_hparams,
142
+ n_flows,
143
+ n_group_size,
144
+ n_layers,
145
+ with_dilation,
146
+ kernel_size,
147
+ scaling_fn,
148
+ take_log_of_input=False,
149
+ n_channels=1024,
150
+ use_quadratic=False,
151
+ n_bins=8,
152
+ n_spline_steps=2,
153
+ ):
154
+ super(BGAP, self).__init__()
155
+ # assert(n_group_size % 2 == 0)
156
+ self.n_flows = n_flows
157
+ self.n_group_size = n_group_size
158
+ self.transforms = torch.nn.ModuleList()
159
+ self.convinv = torch.nn.ModuleList()
160
+ self.n_speaker_dim = n_speaker_dim
161
+ self.scaling_fn = scaling_fn
162
+ self.attribute_processing = AttributeProcessing(take_log_of_input)
163
+ self.n_spline_steps = n_spline_steps
164
+ self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
165
+ n_txt_reduced_dim = self.bottleneck_layer.out_dim
166
+ context_dim = n_txt_reduced_dim * n_group_size + n_speaker_dim
167
+
168
+ if self.n_group_size > 1:
169
+ self.unfold_params = {
170
+ "kernel_size": (n_group_size, 1),
171
+ "stride": n_group_size,
172
+ "padding": 0,
173
+ "dilation": 1,
174
+ }
175
+ self.unfold = nn.Unfold(**self.unfold_params)
176
+
177
+ for k in range(n_flows):
178
+ self.convinv.append(Invertible1x1Conv(n_in_dim * n_group_size))
179
+ if k >= n_flows - self.n_spline_steps:
180
+ left = -3
181
+ right = 3
182
+ top = 3
183
+ bottom = -3
184
+ self.transforms.append(
185
+ SplineTransformationLayer(
186
+ n_in_dim * n_group_size,
187
+ context_dim,
188
+ n_layers,
189
+ with_dilation=with_dilation,
190
+ kernel_size=kernel_size,
191
+ scaling_fn=scaling_fn,
192
+ n_channels=n_channels,
193
+ top=top,
194
+ bottom=bottom,
195
+ left=left,
196
+ right=right,
197
+ use_quadratic=use_quadratic,
198
+ n_bins=n_bins,
199
+ )
200
+ )
201
+ else:
202
+ self.transforms.append(
203
+ AffineTransformationLayer(
204
+ n_in_dim * n_group_size,
205
+ context_dim,
206
+ n_layers,
207
+ with_dilation=with_dilation,
208
+ kernel_size=kernel_size,
209
+ scaling_fn=scaling_fn,
210
+ affine_model="simple_conv",
211
+ n_channels=n_channels,
212
+ )
213
+ )
214
+
215
+ def fold(self, data):
216
+ """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
217
+ the grouping or "squeeze" operation on input
218
+
219
+ Args:
220
+ data: B x C x T tensor of temporal data
221
+ """
222
+ output_size = (data.shape[2] * self.n_group_size, 1)
223
+ data = nn.functional.fold(
224
+ data, output_size=output_size, **self.unfold_params
225
+ ).squeeze(-1)
226
+ return data
227
+
228
+ def preprocess_context(self, txt_emb, speaker_vecs, std_scale=None):
229
+ if self.n_group_size > 1:
230
+ txt_emb = self.unfold(txt_emb[..., None])
231
+ speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
232
+ context = torch.cat((txt_emb, speaker_vecs), 1)
233
+ return context
234
+
235
+ def forward(self, txt_enc, spk_emb, x, lens):
236
+ """x<tensor>: duration or pitch or energy average"""
237
+ assert txt_enc.size(2) >= x.size(1)
238
+ if len(x.shape) == 2:
239
+ # add channel dimension
240
+ x = x[:, None]
241
+ txt_enc = self.bottleneck_layer(txt_enc)
242
+
243
+ # lens including padded values
244
+ lens_grouped = (lens // self.n_group_size).long()
245
+ context = self.preprocess_context(txt_enc, spk_emb)
246
+ x = self.unfold(x[..., None])
247
+ log_s_list, log_det_W_list = [], []
248
+ for k in range(self.n_flows):
249
+ x, log_s = self.transforms[k](x, context, seq_lens=lens_grouped)
250
+ x, log_det_W = self.convinv[k](x)
251
+ log_det_W_list.append(log_det_W)
252
+ log_s_list.append(log_s)
253
+ # prepare outputs
254
+ outputs = {"z": x, "log_det_W_list": log_det_W_list, "log_s_list": log_s_list}
255
+
256
+ return outputs
257
+
258
+ def infer(self, z, txt_enc, spk_emb, seq_lens):
259
+ txt_enc = self.bottleneck_layer(txt_enc)
260
+ context = self.preprocess_context(txt_enc, spk_emb)
261
+ lens_grouped = (seq_lens // self.n_group_size).long()
262
+ z = self.unfold(z[..., None])
263
+ for k in reversed(range(self.n_flows)):
264
+ z = self.convinv[k](z, inverse=True)
265
+ z = self.transforms[k].forward(
266
+ z, context, inverse=True, seq_lens=lens_grouped
267
+ )
268
+ # z mapped to input domain
269
+ x_hat = self.fold(z)
270
+ # pad on the way out
271
+ return x_hat
272
+
273
+
274
+ class AGAP(torch.nn.Module):
275
+ def __init__(
276
+ self,
277
+ n_in_dim,
278
+ n_speaker_dim,
279
+ n_flows,
280
+ n_hidden,
281
+ n_lstm_layers,
282
+ bottleneck_hparams,
283
+ scaling_fn="exp",
284
+ take_log_of_input=False,
285
+ p_dropout=0.0,
286
+ setup="",
287
+ spline_flow_params=None,
288
+ n_group_size=1,
289
+ ):
290
+ super(AGAP, self).__init__()
291
+ self.flows = torch.nn.ModuleList()
292
+ self.n_group_size = n_group_size
293
+ self.n_speaker_dim = n_speaker_dim
294
+ self.attribute_processing = AttributeProcessing(take_log_of_input)
295
+ self.n_in_dim = n_in_dim
296
+ self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
297
+ n_txt_reduced_dim = self.bottleneck_layer.out_dim
298
+
299
+ if self.n_group_size > 1:
300
+ self.unfold_params = {
301
+ "kernel_size": (n_group_size, 1),
302
+ "stride": n_group_size,
303
+ "padding": 0,
304
+ "dilation": 1,
305
+ }
306
+ self.unfold = nn.Unfold(**self.unfold_params)
307
+
308
+ if spline_flow_params is not None:
309
+ spline_flow_params["n_in_channels"] *= self.n_group_size
310
+
311
+ for i in range(n_flows):
312
+ if i % 2 == 0:
313
+ self.flows.append(
314
+ AR_Step(
315
+ n_in_dim * n_group_size,
316
+ n_speaker_dim,
317
+ n_txt_reduced_dim * n_group_size,
318
+ n_hidden,
319
+ n_lstm_layers,
320
+ scaling_fn,
321
+ spline_flow_params,
322
+ )
323
+ )
324
+ else:
325
+ self.flows.append(
326
+ AR_Back_Step(
327
+ n_in_dim * n_group_size,
328
+ n_speaker_dim,
329
+ n_txt_reduced_dim * n_group_size,
330
+ n_hidden,
331
+ n_lstm_layers,
332
+ scaling_fn,
333
+ spline_flow_params,
334
+ )
335
+ )
336
+
337
+ def fold(self, data):
338
+ """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
339
+ the grouping or "squeeze" operation on input
340
+
341
+ Args:
342
+ data: B x C x T tensor of temporal data
343
+ """
344
+ output_size = (data.shape[2] * self.n_group_size, 1)
345
+ data = nn.functional.fold(
346
+ data, output_size=output_size, **self.unfold_params
347
+ ).squeeze(-1)
348
+ return data
349
+
350
+ def preprocess_context(self, txt_emb, speaker_vecs):
351
+ if self.n_group_size > 1:
352
+ txt_emb = self.unfold(txt_emb[..., None])
353
+ speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
354
+ context = torch.cat((txt_emb, speaker_vecs), 1)
355
+ return context
356
+
357
+ def forward(self, txt_emb, spk_emb, x, lens):
358
+ """x<tensor>: duration or pitch or energy average"""
359
+
360
+ x = x[:, None] if len(x.shape) == 2 else x # add channel dimension
361
+ if self.n_group_size > 1:
362
+ x = self.unfold(x[..., None])
363
+ x = x.permute(2, 0, 1) # permute to time, batch, dims
364
+ x = self.attribute_processing.normalize(x)
365
+
366
+ txt_emb = self.bottleneck_layer(txt_emb)
367
+ context = self.preprocess_context(txt_emb, spk_emb)
368
+ context = context.permute(2, 0, 1) # permute to time, batch, dims
369
+
370
+ lens_groupped = (lens / self.n_group_size).long()
371
+ log_s_list = []
372
+ for i, flow in enumerate(self.flows):
373
+ x, log_s = flow(x, context, lens_groupped)
374
+ log_s_list.append(log_s)
375
+
376
+ x = x.permute(1, 2, 0) # x mapped to z
377
+ log_s_list = [log_s_elt.permute(1, 2, 0) for log_s_elt in log_s_list]
378
+ outputs = {"z": x, "log_s_list": log_s_list, "log_det_W_list": []}
379
+ return outputs
380
+
381
+ def infer(self, z, txt_emb, spk_emb, seq_lens=None):
382
+ if self.n_group_size > 1:
383
+ n_frames = z.shape[2]
384
+ z = self.unfold(z[..., None])
385
+ z = z.permute(2, 0, 1) # permute to time, batch, dims
386
+
387
+ txt_emb = self.bottleneck_layer(txt_emb)
388
+ context = self.preprocess_context(txt_emb, spk_emb)
389
+ context = context.permute(2, 0, 1) # permute to time, batch, dims
390
+
391
+ for i, flow in enumerate(reversed(self.flows)):
392
+ z = flow.infer(z, context)
393
+
394
+ x_hat = z.permute(1, 2, 0)
395
+ if self.n_group_size > 1:
396
+ x_hat = self.fold(x_hat)
397
+ if n_frames > x_hat.shape[2]:
398
+ m = nn.ReflectionPad1d((0, n_frames - x_hat.shape[2]))
399
+ x_hat = m(x_hat)
400
+
401
+ x_hat = self.attribute_processing.denormalize(x_hat)
402
+ return x_hat
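A quick shape check for the bottleneck shared by the three attribute predictors above (a sketch with made-up dimensions; it assumes `ConvNorm` from `common.py` preserves the sequence length, which is how it is used throughout this file): `BottleneckLayerLayer` simply projects the text encoding from `in_dim` to `in_dim // reduction_factor` channels before attribute prediction.

```python
import torch

from attribute_prediction_model import BottleneckLayerLayer

# Hypothetical dimensions; the real ones come from configs/radtts-pp-dap-model.json.
bottleneck = BottleneckLayerLayer(in_dim=512, reduction_factor=4, norm="weightnorm")

txt_enc = torch.randn(2, 512, 37)  # (batch, channels, text length)
reduced = bottleneck(txt_enc)      # -> (2, 128, 37): 512 / 4 = 128 channels
print(reduced.shape, bottleneck.out_dim)
```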
audio_processing.py ADDED
@@ -0,0 +1,328 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ import torch
22
+ import numpy as np
23
+ from scipy.signal import get_window
24
+ from librosa.filters import mel as librosa_mel_fn
25
+ import librosa.util as librosa_util
26
+
27
+
28
+ def window_sumsquare(
29
+ window,
30
+ n_frames,
31
+ hop_length=200,
32
+ win_length=800,
33
+ n_fft=800,
34
+ dtype=np.float32,
35
+ norm=None,
36
+ ):
37
+ """
38
+ # from librosa 0.6
39
+ Compute the sum-square envelope of a window function at a given hop length.
40
+
41
+ This is used to estimate modulation effects induced by windowing
42
+ observations in short-time fourier transforms.
43
+
44
+ Parameters
45
+ ----------
46
+ window : string, tuple, number, callable, or list-like
47
+ Window specification, as in `get_window`
48
+
49
+ n_frames : int > 0
50
+ The number of analysis frames
51
+
52
+ hop_length : int > 0
53
+ The number of samples to advance between frames
54
+
55
+ win_length : [optional]
56
+ The length of the window function. By default, this matches `n_fft`.
57
+
58
+ n_fft : int > 0
59
+ The length of each analysis frame.
60
+
61
+ dtype : np.dtype
62
+ The data type of the output
63
+
64
+ Returns
65
+ -------
66
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
67
+ The sum-squared envelope of the window function
68
+ """
69
+ if win_length is None:
70
+ win_length = n_fft
71
+
72
+ n = n_fft + hop_length * (n_frames - 1)
73
+ x = np.zeros(n, dtype=dtype)
74
+
75
+ # Compute the squared window at the desired length
76
+ win_sq = get_window(window, win_length, fftbins=True)
77
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
78
+ win_sq = librosa_util.pad_center(win_sq, size=n_fft)
79
+
80
+ # Fill the envelope
81
+ for i in range(n_frames):
82
+ sample = i * hop_length
83
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
84
+ return x
85
+
86
+
87
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
88
+ """
89
+ PARAMS
90
+ ------
91
+ magnitudes: spectrogram magnitudes
92
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
93
+ """
94
+
95
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
96
+ angles = angles.astype(np.float32)
97
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
98
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
99
+
100
+ for i in range(n_iters):
101
+ _, angles = stft_fn.transform(signal)
102
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
103
+ return signal
104
+
105
+
106
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
107
+ """
108
+ PARAMS
109
+ ------
110
+ C: compression factor
111
+ """
112
+ return torch.log(torch.clamp(x, min=clip_val) * C)
113
+
114
+
115
+ def dynamic_range_decompression(x, C=1):
116
+ """
117
+ PARAMS
118
+ ------
119
+ C: compression factor used to compress
120
+ """
121
+ return torch.exp(x) / C
122
+
123
+
124
+ class TacotronSTFT(torch.nn.Module):
125
+ def __init__(
126
+ self,
127
+ filter_length=1024,
128
+ hop_length=256,
129
+ win_length=1024,
130
+ n_mel_channels=80,
131
+ sampling_rate=22050,
132
+ mel_fmin=0.0,
133
+ mel_fmax=None,
134
+ ):
135
+ super(TacotronSTFT, self).__init__()
136
+ self.n_mel_channels = n_mel_channels
137
+ self.sampling_rate = sampling_rate
138
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
139
+ mel_basis = librosa_mel_fn(
140
+ sr=sampling_rate,
141
+ n_fft=filter_length,
142
+ n_mels=n_mel_channels,
143
+ fmin=mel_fmin,
144
+ fmax=mel_fmax,
145
+ )
146
+ mel_basis = torch.from_numpy(mel_basis).float()
147
+ self.register_buffer("mel_basis", mel_basis)
148
+
149
+ def spectral_normalize(self, magnitudes):
150
+ output = dynamic_range_compression(magnitudes)
151
+ return output
152
+
153
+ def spectral_de_normalize(self, magnitudes):
154
+ output = dynamic_range_decompression(magnitudes)
155
+ return output
156
+
157
+ def mel_spectrogram(self, y):
158
+ """Computes mel-spectrograms from a batch of waves
159
+ PARAMS
160
+ ------
161
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
162
+
163
+ RETURNS
164
+ -------
165
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
166
+ """
167
+ assert torch.min(y.data) >= -1
168
+ assert torch.max(y.data) <= 1
169
+
170
+ magnitudes, phases = self.stft_fn.transform(y)
171
+ magnitudes = magnitudes.data
172
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
173
+ mel_output = self.spectral_normalize(mel_output)
174
+ return mel_output
175
+
176
+
177
+ """
178
+ BSD 3-Clause License
179
+
180
+ Copyright (c) 2017, Prem Seetharaman
181
+ All rights reserved.
182
+
183
+ * Redistribution and use in source and binary forms, with or without
184
+ modification, are permitted provided that the following conditions are met:
185
+
186
+ * Redistributions of source code must retain the above copyright notice,
187
+ this list of conditions and the following disclaimer.
188
+
189
+ * Redistributions in binary form must reproduce the above copyright notice, this
190
+ list of conditions and the following disclaimer in the
191
+ documentation and/or other materials provided with the distribution.
192
+
193
+ * Neither the name of the copyright holder nor the names of its
194
+ contributors may be used to endorse or promote products derived from this
195
+ software without specific prior written permission.
196
+
197
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
198
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
199
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
200
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
201
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
202
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
203
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
204
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
205
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
206
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
207
+ """
208
+ import torch.nn.functional as F
209
+ from torch.autograd import Variable
210
+ from scipy.signal import get_window
211
+ from librosa.util import pad_center, tiny
212
+
213
+
214
+ class STFT(torch.nn.Module):
215
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
216
+
217
+ def __init__(
218
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
219
+ ):
220
+ super(STFT, self).__init__()
221
+ self.filter_length = filter_length
222
+ self.hop_length = hop_length
223
+ self.win_length = win_length
224
+ self.window = window
225
+ self.forward_transform = None
226
+ scale = self.filter_length / self.hop_length
227
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
228
+
229
+ cutoff = int((self.filter_length / 2 + 1))
230
+ fourier_basis = np.vstack(
231
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
232
+ )
233
+
234
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
235
+ inverse_basis = torch.FloatTensor(
236
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
237
+ )
238
+
239
+ if window is not None:
240
+ assert win_length >= filter_length
241
+ # get window and zero center pad it to filter_length
242
+ fft_window = get_window(window, win_length, fftbins=True)
243
+ fft_window = pad_center(fft_window, size=filter_length)
244
+ fft_window = torch.from_numpy(fft_window).float()
245
+
246
+ # window the bases
247
+ forward_basis *= fft_window
248
+ inverse_basis *= fft_window
249
+
250
+ self.register_buffer("forward_basis", forward_basis.float())
251
+ self.register_buffer("inverse_basis", inverse_basis.float())
252
+
253
+ def transform(self, input_data):
254
+ num_batches = input_data.size(0)
255
+ num_samples = input_data.size(1)
256
+
257
+ self.num_samples = num_samples
258
+
259
+ # similar to librosa, reflect-pad the input
260
+ input_data = input_data.view(num_batches, 1, num_samples)
261
+ input_data = F.pad(
262
+ input_data.unsqueeze(1),
263
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
264
+ mode="reflect",
265
+ )
266
+ input_data = input_data.squeeze(1)
267
+
268
+ forward_transform = F.conv1d(
269
+ input_data,
270
+ Variable(self.forward_basis, requires_grad=False),
271
+ stride=self.hop_length,
272
+ padding=0,
273
+ )
274
+
275
+ cutoff = int((self.filter_length / 2) + 1)
276
+ real_part = forward_transform[:, :cutoff, :]
277
+ imag_part = forward_transform[:, cutoff:, :]
278
+
279
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
280
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
281
+
282
+ return magnitude, phase
283
+
284
+ def inverse(self, magnitude, phase):
285
+ recombine_magnitude_phase = torch.cat(
286
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
287
+ )
288
+
289
+ inverse_transform = F.conv_transpose1d(
290
+ recombine_magnitude_phase,
291
+ Variable(self.inverse_basis, requires_grad=False),
292
+ stride=self.hop_length,
293
+ padding=0,
294
+ )
295
+
296
+ if self.window is not None:
297
+ window_sum = window_sumsquare(
298
+ self.window,
299
+ magnitude.size(-1),
300
+ hop_length=self.hop_length,
301
+ win_length=self.win_length,
302
+ n_fft=self.filter_length,
303
+ dtype=np.float32,
304
+ )
305
+ # remove modulation effects
306
+ approx_nonzero_indices = torch.from_numpy(
307
+ np.where(window_sum > tiny(window_sum))[0]
308
+ )
309
+ window_sum = torch.autograd.Variable(
310
+ torch.from_numpy(window_sum), requires_grad=False
311
+ )
312
+ window_sum = window_sum.to(magnitude.device)
313
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
314
+ approx_nonzero_indices
315
+ ]
316
+
317
+ # scale by hop ratio
318
+ inverse_transform *= float(self.filter_length) / self.hop_length
319
+
320
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
321
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
322
+
323
+ return inverse_transform
324
+
325
+ def forward(self, input_data):
326
+ self.magnitude, self.phase = self.transform(input_data)
327
+ reconstruction = self.inverse(self.magnitude, self.phase)
328
+ return reconstruction
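A minimal sketch of extracting a mel spectrogram with the `TacotronSTFT` class above. The parameters below are just the class defaults plus a common `mel_fmax`; the project's actual mel settings live in `configs/radtts-pp-dap-model.json`, and the Space's vocoder runs at 44.1 kHz.

```python
import torch

from audio_processing import TacotronSTFT

stft = TacotronSTFT(
    filter_length=1024, hop_length=256, win_length=1024,
    n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, mel_fmax=8000.0,
)

audio = torch.rand(1, 22050) * 2 - 1  # dummy 1-second waveform scaled into [-1, 1]
mel = stft.mel_spectrogram(audio)     # (1, 80, n_frames) log-compressed mel spectrogram
print(mel.shape)
```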
autoregressive_flow.py ADDED
@@ -0,0 +1,259 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ # AR_Back_Step and AR_Step based on implementation from
23
+ # https://github.com/NVIDIA/flowtron/blob/master/flowtron.py
24
+ # Original license text:
25
+ ###############################################################################
26
+ #
27
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
28
+ # Licensed under the Apache License, Version 2.0 (the "License");
29
+ # you may not use this file except in compliance with the License.
30
+ # You may obtain a copy of the License at
31
+ #
32
+ # http://www.apache.org/licenses/LICENSE-2.0
33
+ #
34
+ # Unless required by applicable law or agreed to in writing, software
35
+ # distributed under the License is distributed on an "AS IS" BASIS,
36
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
37
+ # See the License for the specific language governing permissions and
38
+ # limitations under the License.
39
+ #
40
+ ###############################################################################
41
+ # Original Author and Contact: Rafael Valle
42
+ # Modification by Rafael Valle
43
+
44
+ import torch
45
+ from torch import nn
46
+ from common import DenseLayer, SplineTransformationLayerAR
47
+
48
+
49
+ class AR_Back_Step(torch.nn.Module):
50
+ def __init__(
51
+ self,
52
+ n_attr_channels,
53
+ n_speaker_dim,
54
+ n_text_dim,
55
+ n_hidden,
56
+ n_lstm_layers,
57
+ scaling_fn,
58
+ spline_flow_params=None,
59
+ ):
60
+ super(AR_Back_Step, self).__init__()
61
+ self.ar_step = AR_Step(
62
+ n_attr_channels,
63
+ n_speaker_dim,
64
+ n_text_dim,
65
+ n_hidden,
66
+ n_lstm_layers,
67
+ scaling_fn,
68
+ spline_flow_params,
69
+ )
70
+
71
+ def forward(self, mel, context, lens):
72
+ mel = torch.flip(mel, (0,))
73
+ context = torch.flip(context, (0,))
74
+ # backwards flow, send padded zeros back to end
75
+ for k in range(mel.size(1)):
76
+ mel[:, k] = mel[:, k].roll(lens[k].item(), dims=0)
77
+ context[:, k] = context[:, k].roll(lens[k].item(), dims=0)
78
+
79
+ mel, log_s = self.ar_step(mel, context, lens)
80
+
81
+ # move padded zeros back to beginning
82
+ for k in range(mel.size(1)):
83
+ mel[:, k] = mel[:, k].roll(-lens[k].item(), dims=0)
84
+
85
+ return torch.flip(mel, (0,)), log_s
86
+
87
+ def infer(self, residual, context):
88
+ residual = self.ar_step.infer(
89
+ torch.flip(residual, (0,)), torch.flip(context, (0,))
90
+ )
91
+ residual = torch.flip(residual, (0,))
92
+ return residual
93
+
94
+
95
+ class AR_Step(torch.nn.Module):
96
+ def __init__(
97
+ self,
98
+ n_attr_channels,
99
+ n_speaker_dim,
100
+ n_text_channels,
101
+ n_hidden,
102
+ n_lstm_layers,
103
+ scaling_fn,
104
+ spline_flow_params=None,
105
+ ):
106
+ super(AR_Step, self).__init__()
107
+ if spline_flow_params is not None:
108
+ self.spline_flow = SplineTransformationLayerAR(**spline_flow_params)
109
+ else:
110
+ self.n_out_dims = n_attr_channels
111
+ self.conv = torch.nn.Conv1d(n_hidden, 2 * n_attr_channels, 1)
112
+ self.conv.weight.data = 0.0 * self.conv.weight.data
113
+ self.conv.bias.data = 0.0 * self.conv.bias.data
114
+
115
+ self.attr_lstm = torch.nn.LSTM(n_attr_channels, n_hidden)
116
+ self.lstm = torch.nn.LSTM(
117
+ n_hidden + n_text_channels + n_speaker_dim, n_hidden, n_lstm_layers
118
+ )
119
+
120
+ if spline_flow_params is None:
121
+ self.dense_layer = DenseLayer(in_dim=n_hidden, sizes=[n_hidden, n_hidden])
122
+ self.scaling_fn = scaling_fn
123
+
124
+ def run_padded_sequence(
125
+ self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
126
+ ):
127
+ """Sorts input data by provided ordering (and un-ordering) and runs the
128
+ packed data through the recurrent model
129
+
130
+ Args:
131
+ sorted_idx (torch.tensor): 1D sorting index
132
+ unsort_idx (torch.tensor): 1D unsorting index (inverse sorted_idx)
133
+ lens: lengths of input data (sorted in descending order)
134
+ padded_data (torch.tensor): input sequences (padded)
135
+ recurrent_model (nn.Module): recurrent model to run data through
136
+ Returns:
137
+ hidden_vectors (torch.tensor): outputs of the RNN, in the original,
138
+ unsorted, ordering
139
+ """
140
+
141
+ # sort the data by decreasing length using provided index
142
+ # we assume batch index is in dim=1
143
+ padded_data = padded_data[:, sorted_idx]
144
+ padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens.cpu())
145
+ hidden_vectors = recurrent_model(padded_data)[0]
146
+ hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
147
+ # unsort the results at dim=1 and return
148
+ hidden_vectors = hidden_vectors[:, unsort_idx]
149
+ return hidden_vectors
150
+
151
+ def get_scaling_and_logs(self, scale_unconstrained):
152
+ if self.scaling_fn == "translate":
153
+ s = torch.exp(scale_unconstrained * 0)
154
+ log_s = scale_unconstrained * 0
155
+ elif self.scaling_fn == "exp":
156
+ s = torch.exp(scale_unconstrained)
157
+ log_s = scale_unconstrained  # log(exp(x)) = x
158
+ elif self.scaling_fn == "tanh":
159
+ s = torch.tanh(scale_unconstrained) + 1 + 1e-6
160
+ log_s = torch.log(s)
161
+ elif self.scaling_fn == "sigmoid":
162
+ s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
163
+ log_s = torch.log(s)
164
+ else:
165
+ raise Exception("Scaling fn {} not supp.".format(self.scaling_fn))
166
+
167
+ return s, log_s
168
+
169
+ def forward(self, mel, context, lens):
170
+ dummy = torch.FloatTensor(1, mel.size(1), mel.size(2)).zero_()
171
+ dummy = dummy.type(mel.type())
172
+ # seq_len x batch x dim
173
+ mel0 = torch.cat([dummy, mel[:-1]], 0)
174
+
175
+ self.lstm.flatten_parameters()
176
+ self.attr_lstm.flatten_parameters()
177
+ if lens is not None:
178
+ # collect decreasing length indices
179
+ lens, ids = torch.sort(lens, descending=True)
180
+ original_ids = [0] * lens.size(0)
181
+ for i, ids_i in enumerate(ids):
182
+ original_ids[ids_i] = i
183
+ # mel_seq_len x batch x hidden_dim
184
+ mel_hidden = self.run_padded_sequence(
185
+ ids, original_ids, lens, mel0, self.attr_lstm
186
+ )
187
+ else:
188
+ mel_hidden = self.attr_lstm(mel0)[0]
189
+
190
+ decoder_input = torch.cat((mel_hidden, context), -1)
191
+
192
+ if lens is not None:
193
+ # reorder, run padded sequence and undo reordering
194
+ lstm_hidden = self.run_padded_sequence(
195
+ ids, original_ids, lens, decoder_input, self.lstm
196
+ )
197
+ else:
198
+ lstm_hidden = self.lstm(decoder_input)[0]
199
+
200
+ if hasattr(self, "spline_flow"):
201
+ # spline flow fn expects inputs to be batch, channel, time
202
+ lstm_hidden = lstm_hidden.permute(1, 2, 0)
203
+ mel = mel.permute(1, 2, 0)
204
+ mel, log_s = self.spline_flow(mel, lstm_hidden, inverse=False)
205
+ mel = mel.permute(2, 0, 1)
206
+ log_s = log_s.permute(2, 0, 1)
207
+ else:
208
+ lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
209
+ decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
210
+
211
+ scale, log_s = self.get_scaling_and_logs(
212
+ decoder_output[:, :, : self.n_out_dims]
213
+ )
214
+ bias = decoder_output[:, :, self.n_out_dims :]
215
+
216
+ mel = scale * mel + bias
217
+
218
+ return mel, log_s
219
+
220
+ def infer(self, residual, context):
221
+ total_output = [] # seems 10FPS faster than pre-allocation
222
+
223
+ output = None
224
+ dummy = torch.cuda.FloatTensor(1, residual.size(1), residual.size(2)).zero_()
225
+ self.attr_lstm.flatten_parameters()
226
+
227
+ for i in range(0, residual.size(0)):
228
+ if i == 0:
229
+ output = dummy
230
+ mel_hidden, (h, c) = self.attr_lstm(output)
231
+ else:
232
+ mel_hidden, (h, c) = self.attr_lstm(output, (h, c))
233
+
234
+ decoder_input = torch.cat((mel_hidden, context[i][None]), -1)
235
+
236
+ if i == 0:
237
+ lstm_hidden, (h1, c1) = self.lstm(decoder_input)
238
+ else:
239
+ lstm_hidden, (h1, c1) = self.lstm(decoder_input, (h1, c1))
240
+
241
+ if hasattr(self, "spline_flow"):
242
+ # expects inputs to be batch, channel, time
243
+ lstm_hidden = lstm_hidden.permute(1, 2, 0)
244
+ output = residual[i : i + 1].permute(1, 2, 0)
245
+ output = self.spline_flow(output, lstm_hidden, inverse=True)
246
+ output = output.permute(2, 0, 1)
247
+ else:
248
+ lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
249
+ decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
250
+
251
+ s, log_s = self.get_scaling_and_logs(
252
+ decoder_output[:, :, : decoder_output.size(2) // 2]
253
+ )
254
+ b = decoder_output[:, :, decoder_output.size(2) // 2 :]
255
+ output = (residual[i : i + 1] - b) / s
256
+ total_output.append(output)
257
+
258
+ total_output = torch.cat(total_output, 0)
259
+ return total_output
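The non-spline branch of `AR_Step` is a plain affine flow: `forward()` maps `mel` to `scale * mel + bias`, and `infer()` undoes it with `(residual - bias) / scale`. A tiny self-contained check of that invertibility (illustrative tensors, not the model's real inputs):

```python
import torch

x = torch.randn(3, 2, 80)         # (time, batch, channels) — the layout AR_Step works in
scale = torch.rand_like(x) + 0.5  # strictly positive, as guaranteed by get_scaling_and_logs()
bias = torch.randn_like(x)

z = scale * x + bias              # forward() direction
x_rec = (z - bias) / scale        # infer() direction
assert torch.allclose(x, x_rec, atol=1e-6)
```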
common.py ADDED
@@ -0,0 +1,1083 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ # 1x1InvertibleConv and WN based on implementation from WaveGlow https://github.com/NVIDIA/waveglow/blob/master/glow.py
23
+ # Original license:
24
+ # *****************************************************************************
25
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
26
+ #
27
+ # Redistribution and use in source and binary forms, with or without
28
+ # modification, are permitted provided that the following conditions are met:
29
+ # * Redistributions of source code must retain the above copyright
30
+ # notice, this list of conditions and the following disclaimer.
31
+ # * Redistributions in binary form must reproduce the above copyright
32
+ # notice, this list of conditions and the following disclaimer in the
33
+ # documentation and/or other materials provided with the distribution.
34
+ # * Neither the name of the NVIDIA CORPORATION nor the
35
+ # names of its contributors may be used to endorse or promote products
36
+ # derived from this software without specific prior written permission.
37
+ #
38
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
39
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
40
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
41
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
42
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
43
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
44
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
45
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
46
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
47
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
48
+ #
49
+ # *****************************************************************************
50
+
51
+ import torch
52
+ from torch import nn
53
+ from torch.nn import functional as F
54
+
55
+ import numpy as np
56
+ import ast
57
+
58
+ from splines import (
59
+ piecewise_linear_transform,
60
+ piecewise_linear_inverse_transform,
61
+ unbounded_piecewise_quadratic_transform,
62
+ )
63
+ from partialconv1d import PartialConv1d as pconv1d
64
+ from typing import Tuple
65
+
66
+ use_cuda = torch.cuda.is_available()
67
+
68
+ if use_cuda:
69
+ device = "cuda"
70
+ else:
71
+ device = "cpu"
72
+
73
+
74
+ def update_params(config, params):
75
+ for param in params:
76
+ print(param)
77
+ k, v = param.split("=")
78
+ try:
79
+ v = ast.literal_eval(v)
80
+ except (ValueError, SyntaxError):
81
+ pass
82
+
83
+ k_split = k.split(".")
84
+ if len(k_split) > 1:
85
+ parent_k = k_split[0]
86
+ cur_param = [".".join(k_split[1:]) + "=" + str(v)]
87
+ update_params(config[parent_k], cur_param)
88
+ elif k in config and len(k_split) == 1:
89
+ print(f"overriding {k} with {v}")
90
+ config[k] = v
91
+ else:
92
+ print("{}, {} params not updated".format(k, v))
93
+
94
+
95
+ def get_mask_from_lengths(lengths):
96
+ """Constructs binary mask from a 1D torch tensor of input lengths
97
+
98
+ Args:
99
+ lengths (torch.tensor): 1D tensor
100
+ Returns:
101
+ mask (torch.tensor): num_sequences x max_length binary tensor
102
+ """
103
+ max_len = torch.max(lengths).item()
104
+ if torch.cuda.is_available():
105
+ ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
106
+ else:
107
+ ids = torch.arange(0, max_len, out=torch.LongTensor(max_len))
108
+ mask = (ids < lengths.unsqueeze(1)).bool()
109
+ return mask
110
+
111
+
112
+ class ExponentialClass(torch.nn.Module):
113
+ def __init__(self):
114
+ super(ExponentialClass, self).__init__()
115
+
116
+ def forward(self, x):
117
+ return torch.exp(x)
118
+
119
+
120
+ class LinearNorm(torch.nn.Module):
121
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
122
+ super(LinearNorm, self).__init__()
123
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
124
+
125
+ torch.nn.init.xavier_uniform_(
126
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
127
+ )
128
+
129
+ def forward(self, x):
130
+ return self.linear_layer(x)
131
+
132
+
133
+ class ConvNorm(torch.nn.Module):
134
+ def __init__(
135
+ self,
136
+ in_channels,
137
+ out_channels,
138
+ kernel_size=1,
139
+ stride=1,
140
+ padding=None,
141
+ dilation=1,
142
+ bias=True,
143
+ w_init_gain="linear",
144
+ use_partial_padding=False,
145
+ use_weight_norm=False,
146
+ ):
147
+ super(ConvNorm, self).__init__()
148
+ if padding is None:
149
+ assert kernel_size % 2 == 1
150
+ padding = int(dilation * (kernel_size - 1) / 2)
151
+ self.kernel_size = kernel_size
152
+ self.dilation = dilation
153
+ self.use_partial_padding = use_partial_padding
154
+ self.use_weight_norm = use_weight_norm
155
+ conv_fn = torch.nn.Conv1d
156
+ if self.use_partial_padding:
157
+ conv_fn = pconv1d
158
+ self.conv = conv_fn(
159
+ in_channels,
160
+ out_channels,
161
+ kernel_size=kernel_size,
162
+ stride=stride,
163
+ padding=padding,
164
+ dilation=dilation,
165
+ bias=bias,
166
+ )
167
+ torch.nn.init.xavier_uniform_(
168
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
169
+ )
170
+ if self.use_weight_norm:
171
+ self.conv = nn.utils.weight_norm(self.conv)
172
+
173
+ def forward(self, signal, mask=None):
174
+ if self.use_partial_padding:
175
+ conv_signal = self.conv(signal, mask)
176
+ else:
177
+ conv_signal = self.conv(signal)
178
+ if mask is not None:
179
+ # always re-zero output if mask is
180
+ # available to match zero-padding
181
+ conv_signal = conv_signal * mask
182
+ return conv_signal
183
+
184
+
185
+ class DenseLayer(nn.Module):
186
+ def __init__(self, in_dim=1024, sizes=[1024, 1024]):
187
+ super(DenseLayer, self).__init__()
188
+ in_sizes = [in_dim] + sizes[:-1]
189
+ self.layers = nn.ModuleList(
190
+ [
191
+ LinearNorm(in_size, out_size, bias=True)
192
+ for (in_size, out_size) in zip(in_sizes, sizes)
193
+ ]
194
+ )
195
+
196
+ def forward(self, x):
197
+ for linear in self.layers:
198
+ x = torch.tanh(linear(x))
199
+ return x
200
+
201
+
202
+ class LengthRegulator(nn.Module):
203
+ def __init__(self):
204
+ super().__init__()
205
+
206
+ def forward(self, x, dur):
207
+ output = []
208
+ for x_i, dur_i in zip(x, dur):
209
+ expanded = self.expand(x_i, dur_i)
210
+ output.append(expanded)
211
+ output = self.pad(output)
212
+ return output
213
+
214
+ def expand(self, x, dur):
215
+ output = []
216
+ for i, frame in enumerate(x):
217
+ expanded_len = int(dur[i] + 0.5)
218
+ expanded = frame.expand(expanded_len, -1)
219
+ output.append(expanded)
220
+ output = torch.cat(output, 0)
221
+ return output
222
+
223
+ def pad(self, x):
224
+ output = []
225
+ max_len = max([x[i].size(0) for i in range(len(x))])
226
+ for i, seq in enumerate(x):
227
+ padded = F.pad(seq, [0, 0, 0, max_len - seq.size(0)], "constant", 0.0)
228
+ output.append(padded)
229
+ output = torch.stack(output)
230
+ return output
231
+
232
+
233
+ class ConvLSTMLinear(nn.Module):
234
+ def __init__(
235
+ self,
236
+ in_dim,
237
+ out_dim,
238
+ n_layers=2,
239
+ n_channels=256,
240
+ kernel_size=3,
241
+ p_dropout=0.1,
242
+ lstm_type="bilstm",
243
+ use_linear=True,
244
+ ):
245
+ super(ConvLSTMLinear, self).__init__()
246
+ self.out_dim = out_dim
247
+ self.lstm_type = lstm_type
248
+ self.use_linear = use_linear
249
+ self.dropout = nn.Dropout(p=p_dropout)
250
+
251
+ convolutions = []
252
+ for i in range(n_layers):
253
+ conv_layer = ConvNorm(
254
+ in_dim if i == 0 else n_channels,
255
+ n_channels,
256
+ kernel_size=kernel_size,
257
+ stride=1,
258
+ padding=int((kernel_size - 1) / 2),
259
+ dilation=1,
260
+ w_init_gain="relu",
261
+ )
262
+ conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
263
+ convolutions.append(conv_layer)
264
+
265
+ self.convolutions = nn.ModuleList(convolutions)
266
+
267
+ if not self.use_linear:
268
+ n_channels = out_dim
269
+
270
+ if self.lstm_type != "":
271
+ use_bilstm = False
272
+ lstm_channels = n_channels
273
+ if self.lstm_type == "bilstm":
274
+ use_bilstm = True
275
+ lstm_channels = int(n_channels // 2)
276
+
277
+ self.bilstm = nn.LSTM(
278
+ n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
279
+ )
280
+ lstm_norm_fn_pntr = nn.utils.spectral_norm
281
+ self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
282
+ if self.lstm_type == "bilstm":
283
+ self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")
284
+
285
+ if self.use_linear:
286
+ self.dense = nn.Linear(n_channels, out_dim)
287
+
288
+ def run_padded_sequence(self, context, lens):
289
+ context_embedded = []
290
+ for b_ind in range(context.size()[0]): # TODO: speed up
291
+ curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
292
+ for conv in self.convolutions:
293
+ curr_context = self.dropout(F.relu(conv(curr_context)))
294
+ context_embedded.append(curr_context[0].transpose(0, 1))
295
+ context = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
296
+ return context
297
+
298
+ def run_unsorted_inputs(self, fn, context, lens):
299
+ lens_sorted, ids_sorted = torch.sort(lens, descending=True)
300
+ unsort_ids = [0] * lens.size(0)
301
+ for i in range(len(ids_sorted)):
302
+ unsort_ids[ids_sorted[i]] = i
303
+ lens_sorted = lens_sorted.long().cpu()
304
+
305
+ context = context[ids_sorted]
306
+ context = nn.utils.rnn.pack_padded_sequence(
307
+ context, lens_sorted, batch_first=True
308
+ )
309
+ context = fn(context)[0]
310
+ context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]
311
+
312
+ # map back to original indices
313
+ context = context[unsort_ids]
314
+ return context
315
+
316
+ def forward(self, context, lens):
317
+ if context.size()[0] > 1:
318
+ context = self.run_padded_sequence(context, lens)
319
+ # to B, D, T
320
+ context = context.transpose(1, 2)
321
+ else:
322
+ for conv in self.convolutions:
323
+ context = self.dropout(F.relu(conv(context)))
324
+
325
+ if self.lstm_type != "":
326
+ context = context.transpose(1, 2)
327
+ self.bilstm.flatten_parameters()
328
+ if lens is not None:
329
+ context = self.run_unsorted_inputs(self.bilstm, context, lens)
330
+ else:
331
+ context = self.bilstm(context)[0]
332
+ context = context.transpose(1, 2)
333
+
334
+ x_hat = context
335
+ if self.use_linear:
336
+ x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)
337
+
338
+ return x_hat
339
+
340
+ def infer(self, z, txt_enc, spk_emb):
341
+ x_hat = self.forward(txt_enc, spk_emb)["x_hat"]
342
+ x_hat = self.feature_processing.denormalize(x_hat)
343
+ return x_hat
344
+
345
+
346
+ class Encoder(nn.Module):
347
+ """Encoder module:
348
+ - Three 1-d convolution banks
349
+ - Bidirectional LSTM
350
+ """
351
+
352
+ def __init__(
353
+ self,
354
+ encoder_n_convolutions=3,
355
+ encoder_embedding_dim=512,
356
+ encoder_kernel_size=5,
357
+ norm_fn=nn.BatchNorm1d,
358
+ lstm_norm_fn=None,
359
+ ):
360
+ super(Encoder, self).__init__()
361
+
362
+ convolutions = []
363
+ for _ in range(encoder_n_convolutions):
364
+ conv_layer = nn.Sequential(
365
+ ConvNorm(
366
+ encoder_embedding_dim,
367
+ encoder_embedding_dim,
368
+ kernel_size=encoder_kernel_size,
369
+ stride=1,
370
+ padding=int((encoder_kernel_size - 1) / 2),
371
+ dilation=1,
372
+ w_init_gain="relu",
373
+ use_partial_padding=True,
374
+ ),
375
+ norm_fn(encoder_embedding_dim, affine=True),
376
+ )
377
+ convolutions.append(conv_layer)
378
+ self.convolutions = nn.ModuleList(convolutions)
379
+
380
+ self.lstm = nn.LSTM(
381
+ encoder_embedding_dim,
382
+ int(encoder_embedding_dim / 2),
383
+ 1,
384
+ batch_first=True,
385
+ bidirectional=True,
386
+ )
387
+ if lstm_norm_fn is not None:
388
+ if "spectral" in lstm_norm_fn:
389
+ print("Applying spectral norm to text encoder LSTM")
390
+ lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
391
+ elif "weight" in lstm_norm_fn:
392
+ print("Applying weight norm to text encoder LSTM")
393
+ lstm_norm_fn_pntr = torch.nn.utils.weight_norm
394
+ self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
395
+ self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")
396
+
397
+ @torch.autocast(device, enabled=False)
398
+ def forward(self, x, in_lens):
399
+ """
400
+ Args:
401
+ x (torch.tensor): N x C x L padded input of text embeddings
402
+ in_lens (torch.tensor): 1D tensor of sequence lengths
403
+ """
404
+ if x.size()[0] > 1:
405
+ x_embedded = []
406
+ for b_ind in range(x.size()[0]): # TODO: improve speed
407
+ curr_x = x[b_ind : b_ind + 1, :, : in_lens[b_ind]].clone()
408
+ for conv in self.convolutions:
409
+ curr_x = F.dropout(F.relu(conv(curr_x)), 0.5, self.training)
410
+ x_embedded.append(curr_x[0].transpose(0, 1))
411
+ x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
412
+ else:
413
+ for conv in self.convolutions:
414
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
415
+ x = x.transpose(1, 2)
416
+
417
+ # recent amp change -- change in_lens to int
418
+ in_lens = in_lens.int().cpu()
419
+
420
+ x = nn.utils.rnn.pack_padded_sequence(x, in_lens, batch_first=True)
421
+
422
+ self.lstm.flatten_parameters()
423
+ outputs, _ = self.lstm(x)
424
+
425
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
426
+
427
+ return outputs
428
+
429
+ @torch.autocast(device, enabled=False)
430
+ def infer(self, x):
431
+ for conv in self.convolutions:
432
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training)
433
+
434
+ x = x.transpose(1, 2)
435
+ self.lstm.flatten_parameters()
436
+ outputs, _ = self.lstm(x)
437
+
438
+ return outputs
439
+
440
+
441
+ class Invertible1x1ConvLUS(torch.nn.Module):
442
+ def __init__(self, c, cache_inverse=False):
443
+ super(Invertible1x1ConvLUS, self).__init__()
444
+ # Sample a random orthonormal matrix to initialize weights
445
+ W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
446
+ # Ensure determinant is 1.0 not -1.0
447
+ if torch.det(W) < 0:
448
+ W[:, 0] = -1 * W[:, 0]
449
+ p, lower, upper = torch.lu_unpack(*torch.lu(W))
450
+
451
+ self.register_buffer("p", p)
452
+ # diagonals of lower will always be 1s anyway
453
+ lower = torch.tril(lower, -1)
454
+ lower_diag = torch.diag(torch.eye(c, c))
455
+ self.register_buffer("lower_diag", lower_diag)
456
+ self.lower = nn.Parameter(lower)
457
+ self.upper_diag = nn.Parameter(torch.diag(upper))
458
+ self.upper = nn.Parameter(torch.triu(upper, 1))
459
+ self.cache_inverse = cache_inverse
460
+
461
+ @torch.autocast(device, enabled=False)
462
+ def forward(self, z, inverse=False):
463
+ U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
464
+ L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
465
+ W = torch.mm(self.p, torch.mm(L, U))
466
+ if inverse:
467
+ if not hasattr(self, "W_inverse"):
468
+ # inverse computation
469
+ W_inverse = W.float().inverse()
470
+ if z.type() == "torch.cuda.HalfTensor":
471
+ W_inverse = W_inverse.half()
472
+
473
+ self.W_inverse = W_inverse[..., None]
474
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
475
+ if not self.cache_inverse:
476
+ delattr(self, "W_inverse")
477
+ return z
478
+ else:
479
+ W = W[..., None]
480
+ z = F.conv1d(z, W, bias=None, stride=1, padding=0)
481
+ log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
482
+ return z, log_det_W
483
+
484
+
485
+ class Invertible1x1Conv(torch.nn.Module):
486
+ """
487
+ The layer outputs both the convolution, and the log determinant
488
+ of its weight matrix. If inverse=True it does convolution with
489
+ inverse
490
+ """
491
+
492
+ def __init__(self, c, cache_inverse=False):
493
+ super(Invertible1x1Conv, self).__init__()
494
+ self.conv = torch.nn.Conv1d(
495
+ c, c, kernel_size=1, stride=1, padding=0, bias=False
496
+ )
497
+
498
+ # Sample a random orthonormal matrix to initialize weights
499
+ W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
500
+
501
+ # Ensure determinant is 1.0 not -1.0
502
+ if torch.det(W) < 0:
503
+ W[:, 0] = -1 * W[:, 0]
504
+ W = W.view(c, c, 1)
505
+ self.conv.weight.data = W
506
+ self.cache_inverse = cache_inverse
507
+
508
+ def forward(self, z, inverse=False):
509
+ # DO NOT apply n_of_groups, as it doesn't account for padded sequences
510
+ W = self.conv.weight.squeeze()
511
+
512
+ if inverse:
513
+ if not hasattr(self, "W_inverse"):
514
+ # Inverse computation
515
+ W_inverse = W.float().inverse()
516
+ if z.type() == "torch.cuda.HalfTensor":
517
+ W_inverse = W_inverse.half()
518
+
519
+ self.W_inverse = W_inverse[..., None]
520
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
521
+ if not self.cache_inverse:
522
+ delattr(self, "W_inverse")
523
+ return z
524
+ else:
525
+ # Forward computation
526
+ log_det_W = torch.logdet(W).clone()
527
+ z = self.conv(z)
528
+ return z, log_det_W
529
+
530
+
531
+ class SimpleConvNet(torch.nn.Module):
532
+ def __init__(
533
+ self,
534
+ n_mel_channels,
535
+ n_context_dim,
536
+ final_out_channels,
537
+ n_layers=2,
538
+ kernel_size=5,
539
+ with_dilation=True,
540
+ max_channels=1024,
541
+ zero_init=True,
542
+ use_partial_padding=True,
543
+ ):
544
+ super(SimpleConvNet, self).__init__()
545
+ self.layers = torch.nn.ModuleList()
546
+ self.n_layers = n_layers
547
+ in_channels = n_mel_channels + n_context_dim
548
+ out_channels = -1
549
+ self.use_partial_padding = use_partial_padding
550
+ for i in range(n_layers):
551
+ dilation = 2**i if with_dilation else 1
552
+ padding = int((kernel_size * dilation - dilation) / 2)
553
+ out_channels = min(max_channels, in_channels * 2)
554
+ self.layers.append(
555
+ ConvNorm(
556
+ in_channels,
557
+ out_channels,
558
+ kernel_size=kernel_size,
559
+ stride=1,
560
+ padding=padding,
561
+ dilation=dilation,
562
+ bias=True,
563
+ w_init_gain="relu",
564
+ use_partial_padding=use_partial_padding,
565
+ )
566
+ )
567
+ in_channels = out_channels
568
+
569
+ self.last_layer = torch.nn.Conv1d(
570
+ out_channels, final_out_channels, kernel_size=1
571
+ )
572
+
573
+ if zero_init:
574
+ self.last_layer.weight.data *= 0
575
+ self.last_layer.bias.data *= 0
576
+
577
+ def forward(self, z_w_context, seq_lens: torch.Tensor = None):
578
+ # seq_lens: 1D tensor of sequence lengths
579
+ # output should be b x n_mel_channels x z_w_context.shape(2)
580
+ mask = None
581
+ if seq_lens is not None:
582
+ mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
583
+
584
+ for i in range(self.n_layers):
585
+ z_w_context = self.layers[i](z_w_context, mask)
586
+ z_w_context = torch.relu(z_w_context)
587
+
588
+ z_w_context = self.last_layer(z_w_context)
589
+ return z_w_context
590
+
591
+
592
+ class WN(torch.nn.Module):
593
+ """
594
+ Adapted from WN() module in WaveGlow with modifications to variable names
595
+ """
596
+
597
+ def __init__(
598
+ self,
599
+ n_in_channels,
600
+ n_context_dim,
601
+ n_layers,
602
+ n_channels,
603
+ kernel_size=5,
604
+ affine_activation="softplus",
605
+ use_partial_padding=True,
606
+ ):
607
+ super(WN, self).__init__()
608
+ assert kernel_size % 2 == 1
609
+ assert n_channels % 2 == 0
610
+ self.n_layers = n_layers
611
+ self.n_channels = n_channels
612
+ self.in_layers = torch.nn.ModuleList()
613
+ self.res_skip_layers = torch.nn.ModuleList()
614
+ start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
615
+ start = torch.nn.utils.weight_norm(start, name="weight")
616
+ self.start = start
617
+ self.softplus = torch.nn.Softplus()
618
+ self.affine_activation = affine_activation
619
+ self.use_partial_padding = use_partial_padding
620
+ # Initializing last layer to 0 makes the affine coupling layers
621
+ # do nothing at first. This helps with training stability
622
+ end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
623
+ end.weight.data.zero_()
624
+ end.bias.data.zero_()
625
+ self.end = end
626
+
627
+ for i in range(n_layers):
628
+ dilation = 2**i
629
+ padding = int((kernel_size * dilation - dilation) / 2)
630
+ in_layer = ConvNorm(
631
+ n_channels,
632
+ n_channels,
633
+ kernel_size=kernel_size,
634
+ dilation=dilation,
635
+ padding=padding,
636
+ use_partial_padding=use_partial_padding,
637
+ use_weight_norm=True,
638
+ )
639
+ # in_layer = nn.Conv1d(n_channels, n_channels, kernel_size,
640
+ # dilation=dilation, padding=padding)
641
+ # in_layer = nn.utils.weight_norm(in_layer)
642
+ self.in_layers.append(in_layer)
643
+ res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
644
+ res_skip_layer = nn.utils.weight_norm(res_skip_layer)
645
+ self.res_skip_layers.append(res_skip_layer)
646
+
647
+ def forward(
648
+ self,
649
+ forward_input: Tuple[torch.Tensor, torch.Tensor],
650
+ seq_lens: torch.Tensor = None,
651
+ ):
652
+ z, context = forward_input
653
+ z = torch.cat((z, context), 1) # append context to z as well
654
+ z = self.start(z)
655
+ output = torch.zeros_like(z)
656
+ mask = None
657
+ if seq_lens is not None:
658
+ mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
659
+ non_linearity = torch.relu
660
+ if self.affine_activation == "softplus":
661
+ non_linearity = self.softplus
662
+
663
+ for i in range(self.n_layers):
664
+ z = non_linearity(self.in_layers[i](z, mask))
665
+ res_skip_acts = non_linearity(self.res_skip_layers[i](z))
666
+ output = output + res_skip_acts
667
+
668
+ output = self.end(output) # [B, dim, seq_len]
669
+ return output
670
+
671
+
672
+ # Affine Coupling Layers
673
+ class SplineTransformationLayerAR(torch.nn.Module):
674
+ def __init__(
675
+ self,
676
+ n_in_channels,
677
+ n_context_dim,
678
+ n_layers,
679
+ affine_model="simple_conv",
680
+ kernel_size=1,
681
+ scaling_fn="exp",
682
+ affine_activation="softplus",
683
+ n_channels=1024,
684
+ n_bins=8,
685
+ left=-6,
686
+ right=6,
687
+ bottom=-6,
688
+ top=6,
689
+ use_quadratic=False,
690
+ ):
691
+ super(SplineTransformationLayerAR, self).__init__()
692
+ self.n_in_channels = n_in_channels # input dimensions
693
+ self.left = left
694
+ self.right = right
695
+ self.bottom = bottom
696
+ self.top = top
697
+ self.n_bins = n_bins
698
+ self.spline_fn = piecewise_linear_transform
699
+ self.inv_spline_fn = piecewise_linear_inverse_transform
700
+ self.use_quadratic = use_quadratic
701
+
702
+ if self.use_quadratic:
703
+ self.spline_fn = unbounded_piecewise_quadratic_transform
704
+ self.inv_spline_fn = unbounded_piecewise_quadratic_transform
705
+ self.n_bins = 2 * self.n_bins + 1
706
+ final_out_channels = self.n_in_channels * self.n_bins
707
+
708
+ # autoregressive flow, kernel size 1 and no dilation
709
+ self.param_predictor = SimpleConvNet(
710
+ n_context_dim,
711
+ 0,
712
+ final_out_channels,
713
+ n_layers,
714
+ with_dilation=False,
715
+ kernel_size=1,
716
+ zero_init=True,
717
+ use_partial_padding=False,
718
+ )
719
+
720
+ # output is unnormalized bin weights
721
+
722
+ def normalize(self, z, inverse):
723
+ # normalize to [0, 1]
724
+ if inverse:
725
+ z = (z - self.bottom) / (self.top - self.bottom)
726
+ else:
727
+ z = (z - self.left) / (self.right - self.left)
728
+
729
+ return z
730
+
731
+ def denormalize(self, z, inverse):
732
+ if inverse:
733
+ z = z * (self.right - self.left) + self.left
734
+ else:
735
+ z = z * (self.top - self.bottom) + self.bottom
736
+
737
+ return z
738
+
739
+ def forward(self, z, context, inverse=False):
740
+ b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
741
+
742
+ z = self.normalize(z, inverse)
743
+
744
+ if z.min() < 0.0 or z.max() > 1.0:
745
+ print("spline z scaled beyond [0, 1]", z.min(), z.max())
746
+
747
+ z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
748
+ affine_params = self.param_predictor(context)
749
+ q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
750
+ with torch.autocast(device, enabled=False):
751
+ if self.use_quadratic:
752
+ w = q_tilde[:, :, : self.n_bins // 2]
753
+ v = q_tilde[:, :, self.n_bins // 2 :]
754
+ z_tformed, log_s = self.spline_fn(
755
+ z_reshaped.float(), w.float(), v.float(), inverse=inverse
756
+ )
757
+ else:
758
+ z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())
759
+
760
+ z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
761
+ z = self.denormalize(z, inverse)
762
+ if inverse:
763
+ return z
764
+
765
+ log_s = log_s.reshape(b_s, t_s, -1)
766
+ log_s = log_s.permute(0, 2, 1)
767
+ log_s = log_s + c_s * (
768
+ np.log(self.top - self.bottom) - np.log(self.right - self.left)
769
+ )
770
+ return z, log_s
771
+
772
+
773
+ class SplineTransformationLayer(torch.nn.Module):
774
+ def __init__(
775
+ self,
776
+ n_mel_channels,
777
+ n_context_dim,
778
+ n_layers,
779
+ with_dilation=True,
780
+ kernel_size=5,
781
+ scaling_fn="exp",
782
+ affine_activation="softplus",
783
+ n_channels=1024,
784
+ n_bins=8,
785
+ left=-4,
786
+ right=4,
787
+ bottom=-4,
788
+ top=4,
789
+ use_quadratic=False,
790
+ ):
791
+ super(SplineTransformationLayer, self).__init__()
792
+ self.n_mel_channels = n_mel_channels # input dimensions
793
+ self.half_mel_channels = int(n_mel_channels / 2) # half, because we split
794
+ self.left = left
795
+ self.right = right
796
+ self.bottom = bottom
797
+ self.top = top
798
+ self.n_bins = n_bins
799
+ self.spline_fn = piecewise_linear_transform
800
+ self.inv_spline_fn = piecewise_linear_inverse_transform
801
+ self.use_quadratic = use_quadratic
802
+
803
+ if self.use_quadratic:
804
+ self.spline_fn = unbounded_piecewise_quadratic_transform
805
+ self.inv_spline_fn = unbounded_piecewise_quadratic_transform
806
+ self.n_bins = 2 * self.n_bins + 1
807
+ final_out_channels = self.half_mel_channels * self.n_bins
808
+
809
+ self.param_predictor = SimpleConvNet(
810
+ self.half_mel_channels,
811
+ n_context_dim,
812
+ final_out_channels,
813
+ n_layers,
814
+ with_dilation=with_dilation,
815
+ kernel_size=kernel_size,
816
+ zero_init=False,
817
+ )
818
+
819
+ # output is unnormalized bin weights
820
+
821
+ def forward(self, z, context, inverse=False, seq_lens=None):
822
+ b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
823
+
824
+ # condition on z_0, transform z_1
825
+ n_half = self.half_mel_channels
826
+ z_0, z_1 = z[:, :n_half], z[:, n_half:]
827
+
828
+ # normalize to [0,1]
829
+ if inverse:
830
+ z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
831
+ else:
832
+ z_1 = (z_1 - self.left) / (self.right - self.left)
833
+
834
+ z_w_context = torch.cat((z_0, context), 1)
835
+ affine_params = self.param_predictor(z_w_context, seq_lens)
836
+ z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
837
+ q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)
838
+
839
+ with torch.autocast(device, enabled=False):
840
+ if self.use_quadratic:
841
+ w = q_tilde[:, :, : self.n_bins // 2]
842
+ v = q_tilde[:, :, self.n_bins // 2 :]
843
+ z_1_tformed, log_s = self.spline_fn(
844
+ z_1_reshaped.float(), w.float(), v.float(), inverse=inverse
845
+ )
846
+ if not inverse:
847
+ log_s = torch.sum(log_s, 1)
848
+ else:
849
+ if inverse:
850
+ z_1_tformed, _dc = self.inv_spline_fn(
851
+ z_1_reshaped.float(), q_tilde.float(), False
852
+ )
853
+ else:
854
+ z_1_tformed, log_s = self.spline_fn(
855
+ z_1_reshaped.float(), q_tilde.float()
856
+ )
857
+
858
+ z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
859
+
860
+ # undo [0, 1] normalization
861
+ if inverse:
862
+ z_1 = z_1 * (self.right - self.left) + self.left
863
+ z = torch.cat((z_0, z_1), dim=1)
864
+ return z
865
+ else: # training
866
+ z_1 = z_1 * (self.top - self.bottom) + self.bottom
867
+ z = torch.cat((z_0, z_1), dim=1)
868
+ log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
869
+ np.log(self.top - self.bottom) - np.log(self.right - self.left)
870
+ )
871
+ return z, log_s
872
+
873
+
874
+ class AffineTransformationLayer(torch.nn.Module):
875
+ def __init__(
876
+ self,
877
+ n_mel_channels,
878
+ n_context_dim,
879
+ n_layers,
880
+ affine_model="simple_conv",
881
+ with_dilation=True,
882
+ kernel_size=5,
883
+ scaling_fn="exp",
884
+ affine_activation="softplus",
885
+ n_channels=1024,
886
+ use_partial_padding=False,
887
+ ):
888
+ super(AffineTransformationLayer, self).__init__()
889
+ if affine_model not in ("wavenet", "simple_conv"):
890
+ raise Exception("{} affine model not supported".format(affine_model))
891
+ if isinstance(scaling_fn, list):
892
+ if not all(
893
+ [x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]
894
+ ):
895
+ raise Exception("{} scaling fn not supported".format(scaling_fn))
896
+ else:
897
+ if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
898
+ raise Exception("{} scaling fn not supported".format(scaling_fn))
899
+
900
+ self.affine_model = affine_model
901
+ self.scaling_fn = scaling_fn
902
+ if affine_model == "wavenet":
903
+ self.affine_param_predictor = WN(
904
+ int(n_mel_channels / 2),
905
+ n_context_dim,
906
+ n_layers=n_layers,
907
+ n_channels=n_channels,
908
+ affine_activation=affine_activation,
909
+ use_partial_padding=use_partial_padding,
910
+ )
911
+ elif affine_model == "simple_conv":
912
+ self.affine_param_predictor = SimpleConvNet(
913
+ int(n_mel_channels / 2),
914
+ n_context_dim,
915
+ n_mel_channels,
916
+ n_layers,
917
+ with_dilation=with_dilation,
918
+ kernel_size=kernel_size,
919
+ use_partial_padding=use_partial_padding,
920
+ )
921
+ self.n_mel_channels = n_mel_channels
922
+
923
+ def get_scaling_and_logs(self, scale_unconstrained):
924
+ if self.scaling_fn == "translate":
925
+ s = torch.exp(scale_unconstrained * 0)
926
+ log_s = scale_unconstrained * 0
927
+ elif self.scaling_fn == "exp":
928
+ s = torch.exp(scale_unconstrained)
929
+ log_s = scale_unconstrained # log(exp
930
+ elif self.scaling_fn == "tanh":
931
+ s = torch.tanh(scale_unconstrained) + 1 + 1e-6
932
+ log_s = torch.log(s)
933
+ elif self.scaling_fn == "sigmoid":
934
+ s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
935
+ log_s = torch.log(s)
936
+ elif isinstance(self.scaling_fn, list):
937
+ s_list, log_s_list = [], []
938
+ for i in range(scale_unconstrained.shape[1]):
939
+ scaling_i = self.scaling_fn[i]
940
+ if scaling_i == "translate":
941
+ s_i = torch.exp(scale_unconstrained[:, i] * 0)
942
+ log_s_i = scale_unconstrained[:, i] * 0
943
+ elif scaling_i == "exp":
944
+ s_i = torch.exp(scale_unconstrained[:, i])
945
+ log_s_i = scale_unconstrained[:, i]
946
+ elif scaling_i == "tanh":
947
+ s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
948
+ log_s_i = torch.log(s_i)
949
+ elif scaling_i == "sigmoid":
950
+ s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
951
+ log_s_i = torch.log(s_i)
952
+ s_list.append(s_i[:, None])
953
+ log_s_list.append(log_s_i[:, None])
954
+ s = torch.cat(s_list, dim=1)
955
+ log_s = torch.cat(log_s_list, dim=1)
956
+ return s, log_s
957
+
958
+ def forward(self, z, context, inverse=False, seq_lens=None):
959
+ n_half = int(self.n_mel_channels / 2)
960
+ z_0, z_1 = z[:, :n_half], z[:, n_half:]
961
+ if self.affine_model == "wavenet":
962
+ affine_params = self.affine_param_predictor(
963
+ (z_0, context), seq_lens=seq_lens
964
+ )
965
+ elif self.affine_model == "simple_conv":
966
+ z_w_context = torch.cat((z_0, context), 1)
967
+ affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)
968
+
969
+ scale_unconstrained = affine_params[:, :n_half, :]
970
+ b = affine_params[:, n_half:, :]
971
+ s, log_s = self.get_scaling_and_logs(scale_unconstrained)
972
+
973
+ if inverse:
974
+ z_1 = (z_1 - b) / s
975
+ z = torch.cat((z_0, z_1), dim=1)
976
+ return z
977
+ else:
978
+ z_1 = s * z_1 + b
979
+ z = torch.cat((z_0, z_1), dim=1)
980
+ return z, log_s
981
+
982
+
983
+ class ConvAttention(torch.nn.Module):
984
+ def __init__(
985
+ self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0
986
+ ):
987
+ super(ConvAttention, self).__init__()
988
+ self.temperature = temperature
989
+ self.softmax = torch.nn.Softmax(dim=3)
990
+ self.log_softmax = torch.nn.LogSoftmax(dim=3)
991
+
992
+ self.key_proj = nn.Sequential(
993
+ ConvNorm(
994
+ n_text_channels,
995
+ n_text_channels * 2,
996
+ kernel_size=3,
997
+ bias=True,
998
+ w_init_gain="relu",
999
+ ),
1000
+ torch.nn.ReLU(),
1001
+ ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
1002
+ )
1003
+
1004
+ self.query_proj = nn.Sequential(
1005
+ ConvNorm(
1006
+ n_mel_channels,
1007
+ n_mel_channels * 2,
1008
+ kernel_size=3,
1009
+ bias=True,
1010
+ w_init_gain="relu",
1011
+ ),
1012
+ torch.nn.ReLU(),
1013
+ ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
1014
+ torch.nn.ReLU(),
1015
+ ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
1016
+ )
1017
+
1018
+ def run_padded_sequence(
1019
+ self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
1020
+ ):
1021
+ """Sorts input data by provided ordering (and un-ordering) and runs the
1022
+ packed data through the recurrent model
1023
+
1024
+ Args:
1025
+ sorted_idx (torch.tensor): 1D sorting index
1026
+ unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
1027
+ lens: lengths of input data (sorted in descending order)
1028
+ padded_data (torch.tensor): input sequences (padded)
1029
+ recurrent_model (nn.Module): recurrent model to run data through
1030
+ Returns:
1031
+ hidden_vectors (torch.tensor): outputs of the RNN, in the original,
1032
+ unsorted, ordering
1033
+ """
1034
+
1035
+ # sort the data by decreasing length using provided index
1036
+ # we assume batch index is in dim=1
1037
+ padded_data = padded_data[:, sorted_idx]
1038
+ padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
1039
+ hidden_vectors = recurrent_model(padded_data)[0]
1040
+ hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
1041
+ # unsort the results at dim=1 and return
1042
+ hidden_vectors = hidden_vectors[:, unsort_idx]
1043
+ return hidden_vectors
1044
+
1045
+ def forward(
1046
+ self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None
1047
+ ):
1048
+ """Attention mechanism for radtts. Unlike in Flowtron, we have no
1049
+ restrictions such as causality etc, since we only need this during
1050
+ training.
1051
+
1052
+ Args:
1053
+ queries (torch.tensor): B x C x T1 tensor (likely mel data)
1054
+ keys (torch.tensor): B x C2 x T2 tensor (text data)
1055
+ query_lens: lengths for sorting the queries in descending order
1056
+ mask (torch.tensor): uint8 binary mask for variable length entries
1057
+ (should be in the T2 domain)
1058
+ Output:
1059
+ attn (torch.tensor): B x 1 x T1 x T2 attention mask.
1060
+ Final dim T2 should sum to 1
1061
+ """
1062
+ temp = 0.0005
1063
+ keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
1064
+ # Beware: this only works because query_dim = attn_dim = n_mel_channels
1065
+ queries_enc = self.query_proj(queries)
1066
+
1067
+ # Gaussian isotropic attention
1068
+ # B x n_attn_dims x T1 x T2
1069
+ attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
1070
+
1071
+ # compute log-likelihood from gaussian
1072
+ eps = 1e-8
1073
+ attn = -temp * attn.sum(1, keepdim=True)
1074
+ if attn_prior is not None:
1075
+ attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)
1076
+
1077
+ attn_logprob = attn.clone()
1078
+
1079
+ if mask is not None:
1080
+ attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))
1081
+
1082
+ attn = self.softmax(attn) # softmax along T2
1083
+ return attn, attn_logprob
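
The module above (imported as `common` by `data.py` below) collects the flow and attention building blocks used by the model: masking helpers, invertible 1x1 convolutions, spline and affine coupling layers, and the `ConvAttention` aligner. As a quick shape sanity check, here is a minimal sketch, not part of the commit, that runs `ConvAttention` together with `get_mask_from_lengths` on random tensors; it assumes this file is saved as `common.py` next to `splines.py` and `partialconv1d.py` so its imports resolve.

```python
# Minimal shape check for ConvAttention (illustrative sketch, not part of the commit).
import torch

from common import ConvAttention, get_mask_from_lengths

attn_layer = ConvAttention(n_mel_channels=80, n_text_channels=512, n_att_channels=80)

B, T1, T2 = 2, 100, 40                       # batch, mel frames, text tokens
queries = torch.randn(B, 80, T1)             # mel features (B x C x T1)
keys = torch.randn(B, 512, T2)               # text encoder outputs (B x C2 x T2)

key_lens = torch.tensor([40, 30])            # second sample is padded past step 30
pad_mask = (~get_mask_from_lengths(key_lens))[..., None]  # B x T2 x 1, True at padding

attn, attn_logprob = attn_layer(queries, keys, query_lens=None, mask=pad_mask)
print(attn.shape)                            # torch.Size([2, 1, 100, 40])
print(attn[1, 0, 0, 30:].sum())              # 0: padded key positions receive no attention
```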
configs/radtts-pp-dap-model.json ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "train_config": {
3
+ "output_directory": "outdir_pp_model",
4
+ "epochs": 10000000,
5
+ "optim_algo": "RAdam",
6
+ "learning_rate": 0.001,
7
+ "weight_decay": 1e-06,
8
+ "sigma": 1.0,
9
+ "iters_per_checkpoint": 1000,
10
+ "batch_size": 16,
11
+ "seed": null,
12
+ "checkpoint_path": "",
13
+ "ignore_layers": [],
14
+ "ignore_layers_warmstart": [],
15
+ "finetune_layers": [],
16
+ "include_layers": [],
17
+ "vocoder_config_path": "models/hifigan_22khz_config.json",
18
+ "vocoder_checkpoint_path": "models/hifigan_ljs_generator_v1.pt",
19
+ "log_attribute_samples": true,
20
+ "log_decoder_samples": true,
21
+ "warmstart_checkpoint_path": "outdir_pp/model_100000",
22
+ "use_amp": true,
23
+ "grad_clip_val": 1.0,
24
+ "loss_weights": {
25
+ "blank_logprob": -1,
26
+ "ctc_loss_weight": 0.1,
27
+ "binarization_loss_weight": 1.0,
28
+ "dur_loss_weight": 1.0,
29
+ "f0_loss_weight": 1.0,
30
+ "energy_loss_weight": 1.0,
31
+ "vpred_loss_weight": 1.0
32
+ },
33
+ "binarization_start_iter": 0,
34
+ "kl_loss_start_iter": 0,
35
+ "unfreeze_modules": "all"
36
+ },
37
+ "data_config": {
38
+ "training_files": {
39
+ "LJS": {
40
+ "basedir": "filelists/",
41
+ "audiodir": "wavs",
42
+ "filelist": "3speakers_ukrainian_train_filelist_dc.txt",
43
+ "lmdbpath": ""
44
+ }
45
+ },
46
+ "validation_files": {
47
+ "LJS": {
48
+ "basedir": "filelists/",
49
+ "audiodir": "wavs",
50
+ "filelist": "3speakers_ukrainian_val_filelist_dc.txt",
51
+ "lmdbpath": ""
52
+ }
53
+ },
54
+ "dur_min": 0.1,
55
+ "dur_max": 10.2,
56
+ "sampling_rate": 22050,
57
+ "filter_length": 1024,
58
+ "hop_length": 256,
59
+ "win_length": 1024,
60
+ "n_mel_channels": 80,
61
+ "mel_fmin": 0.0,
62
+ "mel_fmax": 8000.0,
63
+ "f0_min": 80.0,
64
+ "f0_max": 640.0,
65
+ "max_wav_value": 32768.0,
66
+ "use_f0": true,
67
+ "use_log_f0": 0,
68
+ "use_energy_avg": true,
69
+ "use_scaled_energy": true,
70
+ "symbol_set": "ukrainian",
71
+ "cleaner_names": [
72
+ "ukrainian_cleaners"
73
+ ],
74
+ "heteronyms_path": "tts_text_processing/heteronyms",
75
+ "phoneme_dict_path": "tts_text_processing/cmudict-0.7b",
76
+ "p_phoneme": 0.0,
77
+ "handle_phoneme": "word",
78
+ "handle_phoneme_ambiguous": "ignore",
79
+ "include_speakers": null,
80
+ "n_frames": -1,
81
+ "betabinom_cache_path": "/home/dmytro_chaplinsky/RAD-TTS/radtts-code/cache",
82
+ "lmdb_cache_path": "",
83
+ "use_attn_prior_masking": true,
84
+ "prepend_space_to_text": true,
85
+ "append_space_to_text": true,
86
+ "add_bos_eos_to_text": false,
87
+ "betabinom_scaling_factor": 1.0,
88
+ "distance_tx_unvoiced": false,
89
+ "mel_noise_scale": 0.0
90
+ },
91
+ "dist_config": {
92
+ "dist_backend": "nccl",
93
+ "dist_url": "tcp://localhost:54321"
94
+ },
95
+ "model_config": {
96
+ "n_speakers": 3,
97
+ "n_speaker_dim": 16,
98
+ "n_text": 185,
99
+ "n_text_dim": 512,
100
+ "n_flows": 8,
101
+ "n_conv_layers_per_step": 4,
102
+ "n_mel_channels": 80,
103
+ "n_hidden": 1024,
104
+ "mel_encoder_n_hidden": 512,
105
+ "dummy_speaker_embedding": false,
106
+ "n_early_size": 2,
107
+ "n_early_every": 2,
108
+ "n_group_size": 2,
109
+ "affine_model": "wavenet",
110
+ "include_modules": "decatndpmvpredapm",
111
+ "scaling_fn": "tanh",
112
+ "matrix_decomposition": "LUS",
113
+ "learn_alignments": true,
114
+ "use_speaker_emb_for_alignment": false,
115
+ "attn_straight_through_estimator": true,
116
+ "use_context_lstm": true,
117
+ "context_lstm_norm": "spectral",
118
+ "context_lstm_w_f0_and_energy": true,
119
+ "text_encoder_lstm_norm": "spectral",
120
+ "n_f0_dims": 1,
121
+ "n_energy_avg_dims": 1,
122
+ "use_first_order_features": false,
123
+ "unvoiced_bias_activation": "relu",
124
+ "decoder_use_partial_padding": true,
125
+ "decoder_use_unvoiced_bias": true,
126
+ "ap_pred_log_f0": true,
127
+ "ap_use_unvoiced_bias": false,
128
+ "ap_use_voiced_embeddings": true,
129
+ "dur_model_config": {
130
+ "name": "dap",
131
+ "hparams": {
132
+ "n_speaker_dim": 16,
133
+ "bottleneck_hparams": {
134
+ "in_dim": 512,
135
+ "reduction_factor": 16,
136
+ "norm": "weightnorm",
137
+ "non_linearity": "relu"
138
+ },
139
+ "take_log_of_input": true,
140
+ "arch_hparams": {
141
+ "out_dim": 1,
142
+ "n_layers": 2,
143
+ "n_channels": 256,
144
+ "kernel_size": 3,
145
+ "p_dropout": 0.25,
146
+ "in_dim": 48
147
+ }
148
+ }
149
+ },
150
+ "f0_model_config": {
151
+ "name": "dap",
152
+ "hparams": {
153
+ "n_speaker_dim": 16,
154
+ "bottleneck_hparams": {
155
+ "in_dim": 512,
156
+ "reduction_factor": 16,
157
+ "norm": "weightnorm",
158
+ "non_linearity": "relu"
159
+ },
160
+ "take_log_of_input": false,
161
+ "use_transformer": false,
162
+ "arch_hparams": {
163
+ "out_dim": 1,
164
+ "n_layers": 2,
165
+ "n_channels": 256,
166
+ "kernel_size": 11,
167
+ "p_dropout": 0.5,
168
+ "in_dim": 48
169
+ }
170
+ }
171
+ },
172
+ "energy_model_config": {
173
+ "name": "dap",
174
+ "hparams": {
175
+ "n_speaker_dim": 16,
176
+ "bottleneck_hparams": {
177
+ "in_dim": 512,
178
+ "reduction_factor": 16,
179
+ "norm": "weightnorm",
180
+ "non_linearity": "relu"
181
+ },
182
+ "take_log_of_input": false,
183
+ "use_transformer": false,
184
+ "arch_hparams": {
185
+ "out_dim": 1,
186
+ "n_layers": 2,
187
+ "n_channels": 256,
188
+ "kernel_size": 3,
189
+ "p_dropout": 0.25,
190
+ "in_dim": 48
191
+ }
192
+ }
193
+ },
194
+ "v_model_config": {
195
+ "name": "dap",
196
+ "hparams": {
197
+ "n_speaker_dim": 16,
198
+ "take_log_of_input": false,
199
+ "bottleneck_hparams": {
200
+ "in_dim": 512,
201
+ "reduction_factor": 16,
202
+ "norm": "weightnorm",
203
+ "non_linearity": "relu"
204
+ },
205
+ "arch_hparams": {
206
+ "out_dim": 1,
207
+ "n_layers": 2,
208
+ "n_channels": 256,
209
+ "kernel_size": 3,
210
+ "p_dropout": 0.5,
211
+ "lstm_type": "",
212
+ "use_linear": 1,
213
+ "in_dim": 48
214
+ }
215
+ }
216
+ }
217
+ }
218
+ }
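
The JSON above holds the train, data, and model hyperparameters for the RAD-TTS++ Ukrainian setup. Individual values can be overridden programmatically with `update_params` from the module above, which walks dotted keys into nested sections. A minimal sketch, not part of the commit, assuming the config is saved at `configs/radtts-pp-dap-model.json` and the module above is importable as `common`:

```python
# Load the config and apply "key=value" overrides (illustrative sketch, not part of the commit).
import json

from common import update_params

with open("configs/radtts-pp-dap-model.json", encoding="utf-8") as f:
    config = json.load(f)

# Dotted keys descend into nested sections; values are parsed with ast.literal_eval.
overrides = ["train_config.batch_size=8", "model_config.n_flows=4"]
update_params(config, overrides)

print(config["train_config"]["batch_size"])  # 8
print(config["model_config"]["n_flows"])     # 4
```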
data.py ADDED
@@ -0,0 +1,606 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+
22
+ # Based on https://github.com/NVIDIA/flowtron/blob/master/data.py
23
+ # Original license text:
24
+ ###############################################################################
25
+ #
26
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
27
+ # Licensed under the Apache License, Version 2.0 (the "License");
28
+ # you may not use this file except in compliance with the License.
29
+ # You may obtain a copy of the License at
30
+ #
31
+ # http://www.apache.org/licenses/LICENSE-2.0
32
+ #
33
+ # Unless required by applicable law or agreed to in writing, software
34
+ # distributed under the License is distributed on an "AS IS" BASIS,
35
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
36
+ # See the License for the specific language governing permissions and
37
+ # limitations under the License.
38
+ #
39
+ ###############################################################################
40
+
41
+ import os
42
+ import argparse
43
+ import json
44
+ import numpy as np
45
+ import lmdb
46
+ import pickle as pkl
47
+ import torch
48
+ import torch.utils.data
49
+ from scipy.io.wavfile import read
50
+ from audio_processing import TacotronSTFT
51
+ from tts_text_processing.text_processing import TextProcessing
52
+ from scipy.stats import betabinom
53
+ from librosa import pyin
54
+ from common import update_params
55
+ from scipy.ndimage import distance_transform_edt as distance_transform
56
+
57
+
58
+ def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
59
+ P = phoneme_count
60
+ M = mel_count
61
+ x = np.arange(0, P)
62
+ mel_text_probs = []
63
+ for i in range(1, M + 1):
64
+ a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
65
+ rv = betabinom(P - 1, a, b)
66
+ mel_i_prob = rv.pmf(x)
67
+ mel_text_probs.append(mel_i_prob)
68
+ return torch.tensor(np.array(mel_text_probs))
69
+
70
+
71
+ def load_wav_to_torch(full_path):
72
+ """Loads wav data into a torch array"""
73
+ sampling_rate, data = read(full_path)
74
+ return torch.from_numpy(np.array(data)).float(), sampling_rate
75
+
76
+
77
+ class Data(torch.utils.data.Dataset):
78
+ def __init__(
79
+ self,
80
+ datasets,
81
+ filter_length,
82
+ hop_length,
83
+ win_length,
84
+ sampling_rate,
85
+ n_mel_channels,
86
+ mel_fmin,
87
+ mel_fmax,
88
+ f0_min,
89
+ f0_max,
90
+ max_wav_value,
91
+ use_f0,
92
+ use_energy_avg,
93
+ use_log_f0,
94
+ use_scaled_energy,
95
+ symbol_set,
96
+ cleaner_names,
97
+ heteronyms_path,
98
+ phoneme_dict_path,
99
+ p_phoneme,
100
+ handle_phoneme="word",
101
+ handle_phoneme_ambiguous="ignore",
102
+ speaker_ids=None,
103
+ include_speakers=None,
104
+ n_frames=-1,
105
+ use_attn_prior_masking=True,
106
+ prepend_space_to_text=True,
107
+ append_space_to_text=True,
108
+ add_bos_eos_to_text=False,
109
+ betabinom_cache_path="",
110
+ betabinom_scaling_factor=0.05,
111
+ lmdb_cache_path="",
112
+ dur_min=None,
113
+ dur_max=None,
114
+ combine_speaker_and_emotion=False,
115
+ **kwargs,
116
+ ):
117
+ self.combine_speaker_and_emotion = combine_speaker_and_emotion
118
+ self.max_wav_value = max_wav_value
119
+ self.audio_lmdb_dict = {} # dictionary of lmdbs for audio data
120
+ self.data = self.load_data(datasets)
121
+ self.distance_tx_unvoiced = False
122
+ if "distance_tx_unvoiced" in kwargs.keys():
123
+ self.distance_tx_unvoiced = kwargs["distance_tx_unvoiced"]
124
+ self.stft = TacotronSTFT(
125
+ filter_length=filter_length,
126
+ hop_length=hop_length,
127
+ win_length=win_length,
128
+ sampling_rate=sampling_rate,
129
+ n_mel_channels=n_mel_channels,
130
+ mel_fmin=mel_fmin,
131
+ mel_fmax=mel_fmax,
132
+ )
133
+
134
+ self.do_mel_scaling = kwargs.get("do_mel_scaling", True)
135
+ self.mel_noise_scale = kwargs.get("mel_noise_scale", 0.0)
136
+ self.filter_length = filter_length
137
+ self.hop_length = hop_length
138
+ self.win_length = win_length
139
+ self.mel_fmin = mel_fmin
140
+ self.mel_fmax = mel_fmax
141
+ self.f0_min = f0_min
142
+ self.f0_max = f0_max
143
+ self.use_f0 = use_f0
144
+ self.use_log_f0 = use_log_f0
145
+ self.use_energy_avg = use_energy_avg
146
+ self.use_scaled_energy = use_scaled_energy
147
+ self.sampling_rate = sampling_rate
148
+ self.tp = TextProcessing(
149
+ symbol_set,
150
+ cleaner_names,
151
+ heteronyms_path,
152
+ phoneme_dict_path,
153
+ p_phoneme=p_phoneme,
154
+ handle_phoneme=handle_phoneme,
155
+ handle_phoneme_ambiguous=handle_phoneme_ambiguous,
156
+ prepend_space_to_text=prepend_space_to_text,
157
+ append_space_to_text=append_space_to_text,
158
+ add_bos_eos_to_text=add_bos_eos_to_text,
159
+ )
160
+
161
+ self.dur_min = dur_min
162
+ self.dur_max = dur_max
163
+ if speaker_ids is None or speaker_ids == "":
164
+ self.speaker_ids = self.create_speaker_lookup_table(self.data)
165
+ else:
166
+ self.speaker_ids = speaker_ids
167
+
168
+ print("Number of files", len(self.data))
169
+ if include_speakers is not None:
170
+ for speaker_set, include in include_speakers:
171
+ self.filter_by_speakers_(speaker_set, include)
172
+ print("Number of files after speaker filtering", len(self.data))
173
+
174
+ if dur_min is not None and dur_max is not None:
175
+ self.filter_by_duration_(dur_min, dur_max)
176
+ print("Number of files after duration filtering", len(self.data))
177
+
178
+ self.use_attn_prior_masking = bool(use_attn_prior_masking)
179
+ self.prepend_space_to_text = bool(prepend_space_to_text)
180
+ self.append_space_to_text = bool(append_space_to_text)
181
+ self.betabinom_cache_path = betabinom_cache_path
182
+ self.betabinom_scaling_factor = betabinom_scaling_factor
183
+ self.lmdb_cache_path = lmdb_cache_path
184
+ if self.lmdb_cache_path != "":
185
+ self.cache_data_lmdb = lmdb.open(
186
+ self.lmdb_cache_path, readonly=True, max_readers=1024, lock=False
187
+ ).begin()
188
+
189
+ # # make sure caching path exists
190
+ # if not os.path.exists(self.betabinom_cache_path):
191
+ # os.makedirs(self.betabinom_cache_path)
192
+
193
+ print("Dataloader initialized with no augmentations")
194
+ self.speaker_map = None
195
+ if "speaker_map" in kwargs:
196
+ self.speaker_map = kwargs["speaker_map"]
197
+
198
+ def load_data(self, datasets, split="|"):
199
+ dataset = []
200
+ for dset_name, dset_dict in datasets.items():
201
+ folder_path = dset_dict["basedir"]
202
+ audiodir = dset_dict["audiodir"]
203
+ filename = dset_dict["filelist"]
204
+ audio_lmdb_key = None
205
+ if "lmdbpath" in dset_dict.keys() and len(dset_dict["lmdbpath"]) > 0:
206
+ self.audio_lmdb_dict[dset_name] = lmdb.open(
207
+ dset_dict["lmdbpath"], readonly=True, max_readers=256, lock=False
208
+ ).begin()
209
+ audio_lmdb_key = dset_name
210
+
211
+ wav_folder_prefix = os.path.join(folder_path, audiodir)
212
+ filelist_path = os.path.join(folder_path, filename)
213
+ with open(filelist_path, encoding="utf-8") as f:
214
+ data = [line.strip().split(split) for line in f]
215
+
216
+ for d in data:
217
+ emotion = "other" if len(d) == 3 else d[3]
218
+ duration = -1 if len(d) == 3 else d[4]
219
+ dataset.append(
220
+ {
221
+ "audiopath": os.path.join(wav_folder_prefix, d[0]),
222
+ "text": d[1],
223
+ "speaker": d[2] + "-" + emotion
224
+ if self.combine_speaker_and_emotion
225
+ else d[2],
226
+ "emotion": emotion,
227
+ "duration": float(duration),
228
+ "lmdb_key": audio_lmdb_key,
229
+ }
230
+ )
231
+ return dataset
232
+
233
+ def filter_by_speakers_(self, speakers, include=True):
234
+ print("Include speaker {}: {}".format(speakers, include))
235
+ if include:
236
+ self.data = [x for x in self.data if x["speaker"] in speakers]
237
+ else:
238
+ self.data = [x for x in self.data if x["speaker"] not in speakers]
239
+
240
+ def filter_by_duration_(self, dur_min, dur_max):
241
+ self.data = [
242
+ x
243
+ for x in self.data
244
+ if x["duration"] == -1
245
+ or (x["duration"] >= dur_min and x["duration"] <= dur_max)
246
+ ]
247
+
248
+ def create_speaker_lookup_table(self, data):
249
+ speaker_ids = np.sort(np.unique([x["speaker"] for x in data]))
250
+ d = {speaker_ids[i]: i for i in range(len(speaker_ids))}
251
+ print("Number of speakers:", len(d))
252
+ print("Speaker IDS", d)
253
+ return d
254
+
255
+ def f0_normalize(self, x):
256
+ if self.use_log_f0:
257
+ mask = x >= self.f0_min
258
+ x[mask] = torch.log(x[mask])
259
+ x[~mask] = 0.0
260
+
261
+ return x
262
+
263
+ def f0_denormalize(self, x):
264
+ if self.use_log_f0:
265
+ log_f0_min = np.log(self.f0_min)
266
+ mask = x >= log_f0_min
267
+ x[mask] = torch.exp(x[mask])
268
+ x[~mask] = 0.0
269
+ x[x <= 0.0] = 0.0
270
+
271
+ return x
272
+
273
+ def energy_avg_normalize(self, x):
274
+ if self.use_scaled_energy:
275
+ x = (x + 20.0) / 20.0
276
+ return x
277
+
278
+ def energy_avg_denormalize(self, x):
279
+ if self.use_scaled_energy:
280
+ x = x * 20.0 - 20.0
281
+ return x
282
+
283
+ def get_f0_pvoiced(
284
+ self,
285
+ audio,
286
+ sampling_rate=22050,
287
+ frame_length=1024,
288
+ hop_length=256,
289
+ f0_min=100,
290
+ f0_max=300,
291
+ ):
292
+ audio_norm = audio / self.max_wav_value
293
+ f0, voiced_mask, p_voiced = pyin(
294
+ audio_norm,
295
+ f0_min,
296
+ f0_max,
297
+ sampling_rate,
298
+ frame_length=frame_length,
299
+ win_length=frame_length // 2,
300
+ hop_length=hop_length,
301
+ )
302
+ f0[~voiced_mask] = 0.0
303
+ f0 = torch.FloatTensor(f0)
304
+ p_voiced = torch.FloatTensor(p_voiced)
305
+ voiced_mask = torch.FloatTensor(voiced_mask)
306
+ return f0, voiced_mask, p_voiced
307
+
308
+ def get_energy_average(self, mel):
309
+ energy_avg = mel.mean(0)
310
+ energy_avg = self.energy_avg_normalize(energy_avg)
311
+ return energy_avg
312
+
313
+ def get_mel(self, audio):
314
+ audio_norm = audio / self.max_wav_value
315
+ audio_norm = audio_norm.unsqueeze(0)
316
+ audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
317
+ melspec = self.stft.mel_spectrogram(audio_norm)
318
+ melspec = torch.squeeze(melspec, 0)
319
+ if self.do_mel_scaling:
320
+ melspec = (melspec + 5.5) / 2
321
+ if self.mel_noise_scale > 0:
322
+ melspec += torch.randn_like(melspec) * self.mel_noise_scale
323
+ return melspec
324
+
325
+ def get_speaker_id(self, speaker):
326
+ if self.speaker_map is not None and speaker in self.speaker_map:
327
+ speaker = self.speaker_map[speaker]
328
+
329
+ return torch.LongTensor([self.speaker_ids[speaker]])
330
+
331
+ def get_text(self, text):
332
+ text = self.tp.encode_text(text)
333
+ text = torch.LongTensor(text)
334
+ return text
335
+
336
+ def get_attention_prior(self, n_tokens, n_frames):
337
+ # cache the entire attn_prior by filename
338
+ if self.use_attn_prior_masking:
339
+ filename = "{}_{}".format(n_tokens, n_frames)
340
+ prior_path = os.path.join(self.betabinom_cache_path, filename)
341
+ prior_path += "_prior.pth"
342
+ if self.lmdb_cache_path != "":
343
+ attn_prior = pkl.loads(
344
+ self.cache_data_lmdb.get(prior_path.encode("ascii"))
345
+ )
346
+ elif os.path.exists(prior_path):
347
+ attn_prior = torch.load(prior_path)
348
+ else:
349
+ attn_prior = beta_binomial_prior_distribution(
350
+ n_tokens, n_frames, self.betabinom_scaling_factor
351
+ )
352
+ torch.save(attn_prior, prior_path)
353
+ else:
354
+ attn_prior = torch.ones(n_frames, n_tokens) # all ones baseline
355
+
356
+ return attn_prior
357
+
358
+ def __getitem__(self, index):
359
+ data = self.data[index]
360
+ audiopath, text = data["audiopath"], data["text"]
361
+ speaker_id = data["speaker"]
362
+
363
+ if data["lmdb_key"] is not None:
364
+ data_dict = pkl.loads(
365
+ self.audio_lmdb_dict[data["lmdb_key"]].get(audiopath.encode("ascii"))
366
+ )
367
+ audio = data_dict["audio"]
368
+ sampling_rate = data_dict["sampling_rate"]
369
+ else:
370
+ audio, sampling_rate = load_wav_to_torch(audiopath)
371
+
372
+ if sampling_rate != self.sampling_rate:
373
+ raise ValueError(
374
+ "{} SR doesn't match target {} SR".format(
375
+ sampling_rate, self.sampling_rate
376
+ )
377
+ )
378
+
379
+ mel = self.get_mel(audio)
380
+ f0 = None
381
+ p_voiced = None
382
+ voiced_mask = None
383
+ if self.use_f0:
384
+ filename = "_".join(audiopath.split("/")[-3:])
385
+ f0_path = os.path.join(self.betabinom_cache_path, filename)
386
+ f0_path += "_f0_sr{}_fl{}_hl{}_f0min{}_f0max{}_log{}.pt".format(
387
+ self.sampling_rate,
388
+ self.filter_length,
389
+ self.hop_length,
390
+ self.f0_min,
391
+ self.f0_max,
392
+ self.use_log_f0,
393
+ )
394
+
395
+ dikt = None
396
+ if len(self.lmdb_cache_path) > 0:
397
+ dikt = pkl.loads(self.cache_data_lmdb.get(f0_path.encode("ascii")))
398
+ f0 = dikt["f0"]
399
+ p_voiced = dikt["p_voiced"]
400
+ voiced_mask = dikt["voiced_mask"]
401
+ elif os.path.exists(f0_path):
402
+ try:
403
+ dikt = torch.load(f0_path)
404
+ except Exception:
405
+ print(f"f0 loading from {f0_path} is broken, recomputing.")
406
+
407
+ if dikt is not None:
408
+ f0 = dikt["f0"]
409
+ p_voiced = dikt["p_voiced"]
410
+ voiced_mask = dikt["voiced_mask"]
411
+ else:
412
+ f0, voiced_mask, p_voiced = self.get_f0_pvoiced(
413
+ audio.cpu().numpy(),
414
+ self.sampling_rate,
415
+ self.filter_length,
416
+ self.hop_length,
417
+ self.f0_min,
418
+ self.f0_max,
419
+ )
420
+ print("saving f0 to {}".format(f0_path))
421
+ torch.save(
422
+ {"f0": f0, "voiced_mask": voiced_mask, "p_voiced": p_voiced},
423
+ f0_path,
424
+ )
425
+ if f0 is None:
426
+ raise Exception("STOP, BROKEN F0 {}".format(audiopath))
427
+
428
+ f0 = self.f0_normalize(f0)
429
+ if self.distance_tx_unvoiced:
430
+ mask = f0 <= 0.0
431
+ distance_map = np.log(distance_transform(mask))
432
+ distance_map[distance_map <= 0] = 0.0
433
+ f0 = f0 - distance_map
434
+
435
+ energy_avg = None
436
+ if self.use_energy_avg:
437
+ energy_avg = self.get_energy_average(mel)
438
+ if self.use_scaled_energy and energy_avg.min() < 0.0:
439
+ print(audiopath, "has scaled energy avg smaller than 0")
440
+
441
+ speaker_id = self.get_speaker_id(speaker_id)
442
+ text_encoded = self.get_text(text)
443
+
444
+ attn_prior = self.get_attention_prior(text_encoded.shape[0], mel.shape[1])
445
+
446
+ if not self.use_attn_prior_masking:
447
+ attn_prior = None
448
+
449
+ return {
450
+ "mel": mel,
451
+ "speaker_id": speaker_id,
452
+ "text_encoded": text_encoded,
453
+ "audiopath": audiopath,
454
+ "attn_prior": attn_prior,
455
+ "f0": f0,
456
+ "p_voiced": p_voiced,
457
+ "voiced_mask": voiced_mask,
458
+ "energy_avg": energy_avg,
459
+ }
460
+
461
+ def __len__(self):
462
+ return len(self.data)
463
+
464
+
465
+ class DataCollate:
466
+ """Zero-pads model inputs and targets given number of steps"""
467
+
468
+ def __init__(self, n_frames_per_step=1):
469
+ self.n_frames_per_step = n_frames_per_step
470
+
471
+ def __call__(self, batch):
472
+ """Collate from normalized data"""
473
+ # Right zero-pad all one-hot text sequences to max input length
474
+ input_lengths, ids_sorted_decreasing = torch.sort(
475
+ torch.LongTensor([len(x["text_encoded"]) for x in batch]),
476
+ dim=0,
477
+ descending=True,
478
+ )
479
+
480
+ max_input_len = input_lengths[0]
481
+ text_padded = torch.LongTensor(len(batch), max_input_len)
482
+ text_padded.zero_()
483
+
484
+ for i in range(len(ids_sorted_decreasing)):
485
+ text = batch[ids_sorted_decreasing[i]]["text_encoded"]
486
+ text_padded[i, : text.size(0)] = text
487
+
488
+ # Right zero-pad mel-spec
489
+ num_mel_channels = batch[0]["mel"].size(0)
490
+ max_target_len = max([x["mel"].size(1) for x in batch])
491
+
492
+ # include mel padded, gate padded and speaker ids
493
+ mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
494
+ mel_padded.zero_()
495
+ f0_padded = None
496
+ p_voiced_padded = None
497
+ voiced_mask_padded = None
498
+ energy_avg_padded = None
499
+ if batch[0]["f0"] is not None:
500
+ f0_padded = torch.FloatTensor(len(batch), max_target_len)
501
+ f0_padded.zero_()
502
+
503
+ if batch[0]["p_voiced"] is not None:
504
+ p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
505
+ p_voiced_padded.zero_()
506
+
507
+ if batch[0]["voiced_mask"] is not None:
508
+ voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
509
+ voiced_mask_padded.zero_()
510
+
511
+ if batch[0]["energy_avg"] is not None:
512
+ energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
513
+ energy_avg_padded.zero_()
514
+
515
+ attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
516
+ attn_prior_padded.zero_()
517
+
518
+ output_lengths = torch.LongTensor(len(batch))
519
+ speaker_ids = torch.LongTensor(len(batch))
520
+ audiopaths = []
521
+ for i in range(len(ids_sorted_decreasing)):
522
+ mel = batch[ids_sorted_decreasing[i]]["mel"]
523
+ mel_padded[i, :, : mel.size(1)] = mel
524
+ if batch[ids_sorted_decreasing[i]]["f0"] is not None:
525
+ f0 = batch[ids_sorted_decreasing[i]]["f0"]
526
+ f0_padded[i, : len(f0)] = f0
527
+
528
+ if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
529
+ voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
530
+ voiced_mask_padded[i, : len(f0)] = voiced_mask
531
+
532
+ if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
533
+ p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
534
+ p_voiced_padded[i, : len(f0)] = p_voiced
535
+
536
+ if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
537
+ energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
538
+ energy_avg_padded[i, : len(energy_avg)] = energy_avg
539
+
540
+ output_lengths[i] = mel.size(1)
541
+ speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
542
+ audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
543
+ audiopaths.append(audiopath)
544
+ cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
545
+ if cur_attn_prior is None:
546
+ attn_prior_padded = None
547
+ else:
548
+ attn_prior_padded[
549
+ i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
550
+ ] = cur_attn_prior
551
+
552
+ return {
553
+ "mel": mel_padded,
554
+ "speaker_ids": speaker_ids,
555
+ "text": text_padded,
556
+ "input_lengths": input_lengths,
557
+ "output_lengths": output_lengths,
558
+ "audiopaths": audiopaths,
559
+ "attn_prior": attn_prior_padded,
560
+ "f0": f0_padded,
561
+ "p_voiced": p_voiced_padded,
562
+ "voiced_mask": voiced_mask_padded,
563
+ "energy_avg": energy_avg_padded,
564
+ }
565
+
566
+
567
+ # ===================================================================
568
+ # Takes directory of clean audio and makes directory of spectrograms
569
+ # Useful for making test sets
570
+ # ===================================================================
571
+ if __name__ == "__main__":
572
+ # Get defaults so it can work with no Sacred
573
+ parser = argparse.ArgumentParser()
574
+ parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
575
+ parser.add_argument("-p", "--params", nargs="+", default=[])
576
+ args = parser.parse_args()
577
+ args.rank = 0
578
+
579
+ # Parse configs. Globals nicer in this case
580
+ with open(args.config) as f:
581
+ data = f.read()
582
+
583
+ config = json.loads(data)
584
+ update_params(config, args.params)
585
+ print(config)
586
+
587
+ data_config = config["data_config"]
588
+
589
+ ignore_keys = ["training_files", "validation_files"]
590
+ trainset = Data(
591
+ data_config["training_files"],
592
+ **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
593
+ )
594
+
595
+ valset = Data(
596
+ data_config["validation_files"],
597
+ **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
598
+ speaker_ids=trainset.speaker_ids,
599
+ )
600
+
601
+ collate_fn = DataCollate()
602
+
603
+ for dataset in (trainset, valset):
604
+ for i, batch in enumerate(dataset):
605
+ out = batch
606
+ print("{}/{}".format(i, len(dataset)))
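The `Data` dataset and `DataCollate` collator added above are normally driven through a standard PyTorch `DataLoader` rather than the bare loop in the `__main__` block. A minimal sketch (not part of the commit), assuming a hypothetical `config.json` with the same `data_config` layout and that it runs in the same module as `Data` and `DataCollate`:

```python
# Illustrative only: wires the Data dataset and DataCollate collator above
# into a DataLoader. "config.json" and the batch size are placeholders.
import json

from torch.utils.data import DataLoader

with open("config.json") as f:  # hypothetical config with a "data_config" block
    data_config = json.load(f)["data_config"]

ignore_keys = ["training_files", "validation_files"]
trainset = Data(
    data_config["training_files"],
    **{k: v for k, v in data_config.items() if k not in ignore_keys},
)

loader = DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,
    collate_fn=DataCollate(),  # zero-pads text, mel, f0, energy to the batch max
    drop_last=True,
)

for batch in loader:
    # mel: (B, n_mel_channels, max_frames), text: (B, max_tokens)
    print(batch["mel"].shape, batch["text"].shape)
    break
```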
distributed.py ADDED
@@ -0,0 +1,161 @@
1
+ # Original source: https://github.com/NVIDIA/waveglow/blob/master/distributed.py
2
+ #
3
+ # Original license text:
4
+ # *****************************************************************************
5
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
6
+ #
7
+ # Redistribution and use in source and binary forms, with or without
8
+ # modification, are permitted provided that the following conditions are met:
9
+ # * Redistributions of source code must retain the above copyright
10
+ # notice, this list of conditions and the following disclaimer.
11
+ # * Redistributions in binary form must reproduce the above copyright
12
+ # notice, this list of conditions and the following disclaimer in the
13
+ # documentation and/or other materials provided with the distribution.
14
+ # * Neither the name of the NVIDIA CORPORATION nor the
15
+ # names of its contributors may be used to endorse or promote products
16
+ # derived from this software without specific prior written permission.
17
+ #
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
22
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
25
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+ #
29
+ # *****************************************************************************
30
+
31
+ import os
32
+ import torch
33
+ import torch.distributed as dist
34
+ from torch.autograd import Variable
35
+
36
+
37
+ def reduce_tensor(tensor, num_gpus, reduce_dst=None):
38
+ if num_gpus <= 1: # pass-thru
39
+ return tensor
40
+ rt = tensor.clone()
41
+ if reduce_dst is not None:
42
+ dist.reduce(rt, reduce_dst, op=dist.ReduceOp.SUM)
43
+ else:
44
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
45
+ rt /= num_gpus
46
+ return rt
47
+
48
+
49
+ def init_distributed(rank, num_gpus, dist_backend, dist_url):
50
+ assert torch.cuda.is_available(), "Distributed mode requires CUDA."
51
+
52
+ print("> initializing distributed for rank {} out of {}".format(rank, num_gpus))
53
+
54
+ # Set cuda device so everything is done on the right GPU.
55
+ torch.cuda.set_device(rank % torch.cuda.device_count())
56
+
57
+ init_method = "tcp://"
58
+ master_ip = os.getenv("MASTER_ADDR", "localhost")
59
+ master_port = os.getenv("MASTER_PORT", "6000")
60
+ init_method += master_ip + ":" + master_port
61
+ torch.distributed.init_process_group(
62
+ backend="nccl", world_size=num_gpus, rank=rank, init_method=init_method
63
+ )
64
+
65
+
66
+ def _flatten_dense_tensors(tensors):
67
+ """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
68
+ same dense type.
69
+ Since inputs are dense, the resulting tensor will be a concatenated 1D
70
+ buffer. Element-wise operation on this buffer will be equivalent to
71
+ operating individually.
72
+ Arguments:
73
+ tensors (Iterable[Tensor]): dense tensors to flatten.
74
+ Returns:
75
+ A contiguous 1D buffer containing input tensors.
76
+ """
77
+ if len(tensors) == 1:
78
+ return tensors[0].contiguous().view(-1)
79
+ flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
80
+ return flat
81
+
82
+
83
+ def _unflatten_dense_tensors(flat, tensors):
84
+ """View a flat buffer using the sizes of tensors. Assume that tensors are of
85
+ same dense type, and that flat is given by _flatten_dense_tensors.
86
+ Arguments:
87
+ flat (Tensor): flattened dense tensors to unflatten.
88
+ tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
89
+ unflatten flat.
90
+ Returns:
91
+ Unflattened dense tensors with sizes same as tensors and values from
92
+ flat.
93
+ """
94
+ outputs = []
95
+ offset = 0
96
+ for tensor in tensors:
97
+ numel = tensor.numel()
98
+ outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
99
+ offset += numel
100
+ return tuple(outputs)
101
+
102
+
103
+ def apply_gradient_allreduce(module):
104
+ """
105
+ Modifies existing model to do gradient allreduce, but doesn't change class
106
+ so you don't need "module"
107
+ """
108
+ if not hasattr(dist, "_backend"):
109
+ module.warn_on_half = True
110
+ else:
111
+ module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
112
+
113
+ for p in module.state_dict().values():
114
+ if not torch.is_tensor(p):
115
+ continue
116
+ dist.broadcast(p, 0)
117
+
118
+ def allreduce_params():
119
+ if module.needs_reduction:
120
+ module.needs_reduction = False
121
+ buckets = {}
122
+ for param in module.parameters():
123
+ if param.requires_grad and param.grad is not None:
124
+ tp = type(param.data)
125
+ if tp not in buckets:
126
+ buckets[tp] = []
127
+ buckets[tp].append(param)
128
+ if module.warn_on_half:
129
+ if torch.cuda.HalfTensor in buckets:
130
+ print(
131
+ "WARNING: gloo dist backend for half parameters may be extremely slow."
132
+ + " It is recommended to use the NCCL backend in this case. This currently requires"
133
+ + " PyTorch built from top of tree master."
134
+ )
135
+ module.warn_on_half = False
136
+
137
+ for tp in buckets:
138
+ bucket = buckets[tp]
139
+ grads = [param.grad.data for param in bucket]
140
+ coalesced = _flatten_dense_tensors(grads)
141
+ dist.all_reduce(coalesced)
142
+ coalesced /= dist.get_world_size()
143
+ for buf, synced in zip(
144
+ grads, _unflatten_dense_tensors(coalesced, grads)
145
+ ):
146
+ buf.copy_(synced)
147
+
148
+ for param in list(module.parameters()):
149
+
150
+ def allreduce_hook(*unused):
151
+ Variable._execution_engine.queue_callback(allreduce_params)
152
+
153
+ if param.requires_grad:
154
+ param.register_hook(allreduce_hook)
155
+ dir(param)
156
+
157
+ def set_needs_reduction(self, input, output):
158
+ self.needs_reduction = True
159
+
160
+ module.register_forward_hook(set_needs_reduction)
161
+ return module
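`apply_gradient_allreduce` buckets gradients by tensor type, flattens each bucket into one contiguous buffer, all-reduces that buffer, and copies the averaged values back. The sketch below (not part of the commit) only exercises the flatten/unflatten round-trip on CPU, with no process group, assuming it runs in the same module as the helpers above:

```python
# Illustrative only: checks that _flatten_dense_tensors / _unflatten_dense_tensors
# round-trip a bucket of gradient-shaped tensors.
import torch

grads = [torch.randn(3, 4), torch.randn(7), torch.randn(2, 2, 2)]

flat = _flatten_dense_tensors(grads)           # one contiguous 1-D buffer
assert flat.numel() == sum(g.numel() for g in grads)

flat /= 4.0                                    # element-wise work on the flat buffer

for original, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
    assert synced.shape == original.shape
    assert torch.allclose(synced, original / 4.0)
print("round-trip OK")
```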
filelists/3speakers_ukrainian_train_filelist.txt ADDED
The diff for this file is too large to render. See raw diff
 
filelists/3speakers_ukrainian_train_filelist_dc.txt ADDED
The diff for this file is too large to render. See raw diff
 
filelists/3speakers_ukrainian_val_filelist.txt ADDED
@@ -0,0 +1,85 @@
1
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
2
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
3
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
4
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
5
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
6
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
7
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
8
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
9
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
10
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
11
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
12
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
13
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
14
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
15
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
16
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
17
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
18
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
19
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
20
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
21
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
22
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
23
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
24
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
25
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
26
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
27
+ /home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
28
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
29
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
30
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
31
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
32
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
33
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
34
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
35
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
36
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
37
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
38
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
39
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
40
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
41
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
42
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
43
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
44
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
45
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
46
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
47
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
48
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
49
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
50
+ /home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
51
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
52
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
53
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
54
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
55
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
56
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
57
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
58
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
59
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
60
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
61
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
62
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
63
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
64
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
65
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
66
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
67
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
68
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
69
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
70
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
71
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
72
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
73
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
74
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
75
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
76
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
77
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
78
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
79
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
80
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
81
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
82
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
83
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
84
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
85
+ /home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
filelists/3speakers_ukrainian_val_filelist_dc.txt ADDED
@@ -0,0 +1,85 @@
1
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
2
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
3
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
4
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
5
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
6
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
7
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
8
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
9
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
10
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
11
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
12
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
13
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
14
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
15
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
16
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
17
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
18
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
19
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
20
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
21
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
22
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
23
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
24
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
25
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
26
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
27
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
28
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
29
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
30
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
31
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
32
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
33
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
34
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
35
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
36
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
37
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
38
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
39
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
40
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
41
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
42
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
43
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
44
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
45
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
46
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
47
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
48
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
49
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
50
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
51
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
52
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
53
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
54
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
55
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
56
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
57
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
58
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
59
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
60
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
61
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
62
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
63
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
64
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
65
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
66
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
67
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
68
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
69
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
70
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
71
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
72
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
73
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
74
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
75
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
76
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
77
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
78
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
79
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
80
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
81
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
82
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
83
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
84
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
85
+ /home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
loss.py ADDED
@@ -0,0 +1,228 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+ from common import get_mask_from_lengths
25
+
26
+
27
+ def compute_flow_loss(
28
+ z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
29
+ ):
30
+ log_det_W_total = 0.0
31
+ for i, log_s in enumerate(log_s_list):
32
+ if i == 0:
33
+ log_s_total = torch.sum(log_s * mask)
34
+ if len(log_det_W_list):
35
+ log_det_W_total = log_det_W_list[i]
36
+ else:
37
+ log_s_total = log_s_total + torch.sum(log_s * mask)
38
+ if len(log_det_W_list):
39
+ log_det_W_total += log_det_W_list[i]
40
+
41
+ if len(log_det_W_list):
42
+ log_det_W_total *= n_elements
43
+
44
+ z = z * mask
45
+ prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)
46
+
47
+ loss = prior_NLL - log_s_total - log_det_W_total
48
+
49
+ denom = n_elements * n_dims
50
+ loss = loss / denom
51
+ loss_prior = prior_NLL / denom
52
+ return loss, loss_prior
53
+
54
+
55
+ def compute_regression_loss(x_hat, x, mask, name=False):
56
+ x = x[:, None] if len(x.shape) == 2 else x # add channel dim
57
+ mask = mask[:, None] if len(mask.shape) == 2 else mask # add channel dim
58
+ assert len(x.shape) == len(mask.shape)
59
+
60
+ x = x * mask
61
+ x_hat = x_hat * mask
62
+
63
+ if name == "vpred":
64
+ loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
65
+ else:
66
+ loss = F.mse_loss(x_hat, x, reduction="sum")
67
+ loss = loss / mask.sum()
68
+
69
+ loss_dict = {"loss_{}".format(name): loss}
70
+
71
+ return loss_dict
72
+
73
+
74
+ class AttributePredictionLoss(torch.nn.Module):
75
+ def __init__(self, name, model_config, loss_weight, sigma=1.0):
76
+ super(AttributePredictionLoss, self).__init__()
77
+ self.name = name
78
+ self.sigma = sigma
79
+ self.model_name = model_config["name"]
80
+ self.loss_weight = loss_weight
81
+ self.n_group_size = 1
82
+ if "n_group_size" in model_config["hparams"]:
83
+ self.n_group_size = model_config["hparams"]["n_group_size"]
84
+
85
+ def forward(self, model_output, lens):
86
+ mask = get_mask_from_lengths(lens // self.n_group_size)
87
+ mask = mask[:, None].float()
88
+ loss_dict = {}
89
+ if "z" in model_output:
90
+ n_elements = lens.sum() // self.n_group_size
91
+ n_dims = model_output["z"].size(1)
92
+
93
+ loss, loss_prior = compute_flow_loss(
94
+ model_output["z"],
95
+ model_output["log_det_W_list"],
96
+ model_output["log_s_list"],
97
+ n_elements,
98
+ n_dims,
99
+ mask,
100
+ self.sigma,
101
+ )
102
+ loss_dict = {
103
+ "loss_{}".format(self.name): (loss, self.loss_weight),
104
+ "loss_prior_{}".format(self.name): (loss_prior, 0.0),
105
+ }
106
+ elif "x_hat" in model_output:
107
+ loss_dict = compute_regression_loss(
108
+ model_output["x_hat"], model_output["x"], mask, self.name
109
+ )
110
+ for k, v in loss_dict.items():
111
+ loss_dict[k] = (v, self.loss_weight)
112
+
113
+ if len(loss_dict) == 0:
114
+ raise Exception("loss not supported")
115
+
116
+ return loss_dict
117
+
118
+
119
+ class AttentionCTCLoss(torch.nn.Module):
120
+ def __init__(self, blank_logprob=-1):
121
+ super(AttentionCTCLoss, self).__init__()
122
+ self.log_softmax = torch.nn.LogSoftmax(dim=3)
123
+ self.blank_logprob = blank_logprob
124
+ self.CTCLoss = nn.CTCLoss(zero_infinity=True)
125
+
126
+ def forward(self, attn_logprob, in_lens, out_lens):
127
+ key_lens = in_lens
128
+ query_lens = out_lens
129
+ attn_logprob_padded = F.pad(
130
+ input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
131
+ )
132
+ cost_total = 0.0
133
+ for bid in range(attn_logprob.shape[0]):
134
+ target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
135
+ curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
136
+ : query_lens[bid], :, : key_lens[bid] + 1
137
+ ]
138
+ curr_logprob = self.log_softmax(curr_logprob[None])[0]
139
+ ctc_cost = self.CTCLoss(
140
+ curr_logprob,
141
+ target_seq,
142
+ input_lengths=query_lens[bid : bid + 1],
143
+ target_lengths=key_lens[bid : bid + 1],
144
+ )
145
+ cost_total += ctc_cost
146
+ cost = cost_total / attn_logprob.shape[0]
147
+ return cost
148
+
149
+
150
+ class AttentionBinarizationLoss(torch.nn.Module):
151
+ def __init__(self):
152
+ super(AttentionBinarizationLoss, self).__init__()
153
+
154
+ def forward(self, hard_attention, soft_attention):
155
+ log_sum = torch.log(soft_attention[hard_attention == 1]).sum()
156
+ return -log_sum / hard_attention.sum()
157
+
158
+
159
+ class RADTTSLoss(torch.nn.Module):
160
+ def __init__(
161
+ self,
162
+ sigma=1.0,
163
+ n_group_size=1,
164
+ dur_model_config=None,
165
+ f0_model_config=None,
166
+ energy_model_config=None,
167
+ vpred_model_config=None,
168
+ loss_weights=None,
169
+ ):
170
+ super(RADTTSLoss, self).__init__()
171
+ self.sigma = sigma
172
+ self.n_group_size = n_group_size
173
+ self.loss_weights = loss_weights
174
+ self.attn_ctc_loss = AttentionCTCLoss(
175
+ blank_logprob=loss_weights.get("blank_logprob", -1)
176
+ )
177
+ self.loss_fns = {}
178
+ if dur_model_config is not None:
179
+ self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
180
+ "duration", dur_model_config, loss_weights["dur_loss_weight"]
181
+ )
182
+
183
+ if f0_model_config is not None:
184
+ self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
185
+ "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
186
+ )
187
+
188
+ if energy_model_config is not None:
189
+ self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
190
+ "energy", energy_model_config, loss_weights["energy_loss_weight"]
191
+ )
192
+
193
+ if vpred_model_config is not None:
194
+ self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
195
+ "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
196
+ )
197
+
198
+ def forward(self, model_output, in_lens, out_lens):
199
+ loss_dict = {}
200
+ if len(model_output["z_mel"]):
201
+ n_elements = out_lens.sum() // self.n_group_size
202
+ mask = get_mask_from_lengths(out_lens // self.n_group_size)
203
+ mask = mask[:, None].float()
204
+ n_dims = model_output["z_mel"].size(1)
205
+ loss_mel, loss_prior_mel = compute_flow_loss(
206
+ model_output["z_mel"],
207
+ model_output["log_det_W_list"],
208
+ model_output["log_s_list"],
209
+ n_elements,
210
+ n_dims,
211
+ mask,
212
+ self.sigma,
213
+ )
214
+ loss_dict["loss_mel"] = (loss_mel, 1.0) # loss, weight
215
+ loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)
216
+
217
+ ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
218
+ loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])
219
+
220
+ for k in model_output:
221
+ if k in self.loss_fns:
222
+ if model_output[k] is not None and len(model_output[k]) > 0:
223
+ t_lens = in_lens if "dur" in k else out_lens
224
+ mout = model_output[k]
225
+ for loss_name, v in self.loss_fns[k](mout, t_lens).items():
226
+ loss_dict[loss_name] = v
227
+
228
+ return loss_dict
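`AttentionCTCLoss` treats the increasing token indices 1..K as a CTC target over the (blank-padded) soft attention map, and `AttentionBinarizationLoss` is the negative mean log of the soft attention at the positions selected by the hard alignment. A quick sanity check with random tensors (not part of the commit), assuming it runs alongside the classes above:

```python
# Illustrative only: shape/finiteness check for the attention losses with
# random tensors standing in for real model outputs.
import torch

B, T_frames, T_tokens = 2, 20, 7
attn_scores = torch.randn(B, 1, T_frames, T_tokens)   # (batch, 1, frames, tokens)
in_lens = torch.LongTensor([7, 5])                     # token lengths per sample
out_lens = torch.LongTensor([20, 16])                  # frame lengths per sample

ctc_loss = AttentionCTCLoss()(attn_scores, in_lens, out_lens)

soft_attn = torch.softmax(torch.randn(B, 1, T_frames, T_tokens), dim=3)
hard_attn = torch.zeros_like(soft_attn)
hard_attn[..., 0] = 1.0                                # degenerate hard alignment
bin_loss = AttentionBinarizationLoss()(hard_attn, soft_attn)

print(ctc_loss.item(), bin_loss.item())
```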
partialconv1d.py ADDED
@@ -0,0 +1,77 @@
1
+ # Modified partialconv source code based on implementation from
2
+ # https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
3
+ ###############################################################################
4
+ # BSD 3-Clause License
5
+ #
6
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
7
+ #
8
+ # Author & Contact: Guilin Liu ([email protected])
9
+ ###############################################################################
10
+
11
+ # Original Author & Contact: Guilin Liu ([email protected])
12
+ # Modified by Kevin Shih ([email protected])
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+
18
+
19
+ class PartialConv1d(nn.Conv1d):
20
+ def __init__(self, *args, **kwargs):
21
+ self.multi_channel = False
22
+ self.return_mask = False
23
+ super(PartialConv1d, self).__init__(*args, **kwargs)
24
+
25
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
26
+ self.slide_winsize = (
27
+ self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
28
+ )
29
+
30
+ self.last_size = (None, None, None)
31
+ self.update_mask = None
32
+ self.mask_ratio = None
33
+
34
+ @torch.jit.ignore
35
+ def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
36
+ """
37
+ input: standard input to a 1D conv
38
+ mask_in: binary mask for valid values, same shape as input
39
+ """
40
+ assert len(input.shape) == 3
41
+ # if a mask is input, or tensor shape changed, update mask ratio
42
+ if mask_in is not None or self.last_size != tuple(input.shape):
43
+ self.last_size = tuple(input.shape)
44
+ with torch.no_grad():
45
+ if self.weight_maskUpdater.type() != input.type():
46
+ self.weight_maskUpdater = self.weight_maskUpdater.to(input)
47
+ if mask_in is None:
48
+ mask = torch.ones(1, 1, input.data.shape[2]).to(input)
49
+ else:
50
+ mask = mask_in
51
+ self.update_mask = F.conv1d(
52
+ mask,
53
+ self.weight_maskUpdater,
54
+ bias=None,
55
+ stride=self.stride,
56
+ padding=self.padding,
57
+ dilation=self.dilation,
58
+ groups=1,
59
+ )
60
+ # for mixed precision training, change 1e-8 to 1e-6
61
+ self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
62
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
63
+ self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
64
+ raw_out = super(PartialConv1d, self).forward(
65
+ torch.mul(input, mask) if mask_in is not None else input
66
+ )
67
+ if self.bias is not None:
68
+ bias_view = self.bias.view(1, self.out_channels, 1)
69
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
70
+ output = torch.mul(output, self.update_mask)
71
+ else:
72
+ output = torch.mul(raw_out, self.mask_ratio)
73
+
74
+ if self.return_mask:
75
+ return output, self.update_mask
76
+ else:
77
+ return output
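`PartialConv1d` renormalises a standard `Conv1d` by the fraction of valid (unmasked) samples inside each receptive field and zeroes positions whose receptive field contains no valid input at all. A minimal sketch of that behaviour (not part of the commit), assuming it runs in the same module as the class above:

```python
# Illustrative only: the same PartialConv1d layer applied with and without a
# validity mask over the last four frames.
import torch

torch.manual_seed(0)
conv = PartialConv1d(in_channels=2, out_channels=4, kernel_size=3, padding=1)

x = torch.randn(1, 2, 10)
mask = torch.ones(1, 1, 10)
mask[..., 6:] = 0.0                            # last 4 frames are padding

out_plain = conv(x)                            # no mask: all positions treated as valid
out_masked = conv(x, mask_in=mask)             # renormalised by valid-sample count

print(out_plain.shape, out_masked.shape)       # torch.Size([1, 4, 10]) twice
# Positions whose 3-sample receptive field is fully masked come out as zero:
print(out_masked[..., 7:].abs().max().item())  # 0.0
```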
radam.py ADDED
@@ -0,0 +1,114 @@
1
+ # Original source taken from https://github.com/LiyuanLucasLiu/RAdam
2
+ #
3
+ # Copyright 2019 Liyuan Liu
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import math
17
+
18
+ import torch
19
+
20
+ # pylint: disable=no-name-in-module
21
+ from torch.optim.optimizer import Optimizer
22
+
23
+
24
+ class RAdam(Optimizer):
25
+ """RAdam optimizer"""
26
+
27
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
28
+ """
29
+ Init
30
+
31
+ :param params: parameters to optimize
32
+ :param lr: learning rate
33
+ :param betas: beta
34
+ :param eps: numerical precision
35
+ :param weight_decay: weight decay weight
36
+ """
37
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
38
+ self.buffer = [[None, None, None] for _ in range(10)]
39
+ super().__init__(params, defaults)
40
+
41
+ def step(self, closure=None):
42
+ loss = None
43
+ if closure is not None:
44
+ loss = closure()
45
+
46
+ for group in self.param_groups:
47
+ for p in group["params"]:
48
+ if p.grad is None:
49
+ continue
50
+ grad = p.grad.data.float()
51
+ if grad.is_sparse:
52
+ raise RuntimeError("RAdam does not support sparse gradients")
53
+
54
+ p_data_fp32 = p.data.float()
55
+
56
+ state = self.state[p]
57
+
58
+ if len(state) == 0:
59
+ state["step"] = 0
60
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
61
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
62
+ else:
63
+ state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
64
+ state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
65
+
66
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
67
+ beta1, beta2 = group["betas"]
68
+
69
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
70
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
71
+
72
+ state["step"] += 1
73
+ buffered = self.buffer[int(state["step"] % 10)]
74
+ if state["step"] == buffered[0]:
75
+ N_sma, step_size = buffered[1], buffered[2]
76
+ else:
77
+ buffered[0] = state["step"]
78
+ beta2_t = beta2 ** state["step"]
79
+ N_sma_max = 2 / (1 - beta2) - 1
80
+ N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
81
+ buffered[1] = N_sma
82
+
83
+ # more conservative since it's an approximated value
84
+ if N_sma >= 5:
85
+ step_size = (
86
+ group["lr"]
87
+ * math.sqrt(
88
+ (1 - beta2_t)
89
+ * (N_sma - 4)
90
+ / (N_sma_max - 4)
91
+ * (N_sma - 2)
92
+ / N_sma
93
+ * N_sma_max
94
+ / (N_sma_max - 2)
95
+ )
96
+ / (1 - beta1 ** state["step"])
97
+ )
98
+ else:
99
+ step_size = group["lr"] / (1 - beta1 ** state["step"])
100
+ buffered[2] = step_size
101
+
102
+ if group["weight_decay"] != 0:
103
+ p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
104
+
105
+ # more conservative since it's an approximated value
106
+ if N_sma >= 5:
107
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
108
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
109
+ else:
110
+ p_data_fp32.add_(-step_size, exp_avg)
111
+
112
+ p.data.copy_(p_data_fp32)
113
+
114
+ return loss
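A usage sketch for the `RAdam` optimizer above (not part of the commit). The implementation keeps legacy in-place signatures such as `exp_avg.mul_(beta1).add_(1 - beta1, grad)`, so this assumes a PyTorch version that still accepts them; recent releases expect the `add_(grad, alpha=1 - beta1)` form instead.

```python
# Illustrative only: a few RAdam steps on a tiny regression problem.
import torch

model = torch.nn.Linear(16, 4)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)

x = torch.randn(8, 16)
target = torch.randn(8, 4)

for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    print(loss.item())
```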
radtts.py ADDED
@@ -0,0 +1,936 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # Permission is hereby granted, free of charge, to any person obtaining a
5
+ # copy of this software and associated documentation files (the "Software"),
6
+ # to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
+ # and/or sell copies of the Software, and to permit persons to whom the
9
+ # Software is furnished to do so, subject to the following conditions:
10
+ #
11
+ # The above copyright notice and this permission notice shall be included in
12
+ # all copies or substantial portions of the Software.
13
+ #
14
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
+ # DEALINGS IN THE SOFTWARE.
21
+ import torch
22
+ from torch import nn
23
+ from common import Encoder, LengthRegulator, ConvAttention
24
+ from common import Invertible1x1ConvLUS, Invertible1x1Conv
25
+ from common import AffineTransformationLayer, LinearNorm, ExponentialClass
26
+ from common import get_mask_from_lengths
27
+ from attribute_prediction_model import get_attribute_prediction_model
28
+ from alignment import mas_width1 as mas
29
+
30
+
31
+ class FlowStep(nn.Module):
32
+ def __init__(
33
+ self,
34
+ n_mel_channels,
35
+ n_context_dim,
36
+ n_layers,
37
+ affine_model="simple_conv",
38
+ scaling_fn="exp",
39
+ matrix_decomposition="",
40
+ affine_activation="softplus",
41
+ use_partial_padding=False,
42
+ cache_inverse=False,
43
+ ):
44
+ super(FlowStep, self).__init__()
45
+ if matrix_decomposition == "LUS":
46
+ self.invtbl_conv = Invertible1x1ConvLUS(
47
+ n_mel_channels, cache_inverse=cache_inverse
48
+ )
49
+ else:
50
+ self.invtbl_conv = Invertible1x1Conv(
51
+ n_mel_channels, cache_inverse=cache_inverse
52
+ )
53
+
54
+ self.affine_tfn = AffineTransformationLayer(
55
+ n_mel_channels,
56
+ n_context_dim,
57
+ n_layers,
58
+ affine_model=affine_model,
59
+ scaling_fn=scaling_fn,
60
+ affine_activation=affine_activation,
61
+ use_partial_padding=use_partial_padding,
62
+ )
63
+
64
+ def enable_inverse_cache(self):
65
+ self.invtbl_conv.cache_inverse = True
66
+
67
+ def forward(self, z, context, inverse=False, seq_lens=None):
68
+ if inverse: # for inference z-> mel
69
+ z = self.affine_tfn(z, context, inverse, seq_lens=seq_lens)
70
+ z = self.invtbl_conv(z, inverse)
71
+ return z
72
+ else: # training mel->z
73
+ z, log_det_W = self.invtbl_conv(z)
74
+ z, log_s = self.affine_tfn(z, context, seq_lens=seq_lens)
75
+ return z, log_det_W, log_s
76
+
77
+
78
+ class RADTTS(torch.nn.Module):
79
+ def __init__(
80
+ self,
81
+ n_speakers,
82
+ n_speaker_dim,
83
+ n_text,
84
+ n_text_dim,
85
+ n_flows,
86
+ n_conv_layers_per_step,
87
+ n_mel_channels,
88
+ n_hidden,
89
+ mel_encoder_n_hidden,
90
+ dummy_speaker_embedding,
91
+ n_early_size,
92
+ n_early_every,
93
+ n_group_size,
94
+ affine_model,
95
+ dur_model_config,
96
+ f0_model_config,
97
+ energy_model_config,
98
+ v_model_config=None,
99
+ include_modules="dec",
100
+ scaling_fn="exp",
101
+ matrix_decomposition="",
102
+ learn_alignments=False,
103
+ affine_activation="softplus",
104
+ attn_use_CTC=True,
105
+ use_speaker_emb_for_alignment=False,
106
+ use_context_lstm=False,
107
+ context_lstm_norm=None,
108
+ text_encoder_lstm_norm=None,
109
+ n_f0_dims=0,
110
+ n_energy_avg_dims=0,
111
+ context_lstm_w_f0_and_energy=True,
112
+ use_first_order_features=False,
113
+ unvoiced_bias_activation="",
114
+ ap_pred_log_f0=False,
115
+ **kwargs,
116
+ ):
117
+ super(RADTTS, self).__init__()
118
+ assert n_early_size % 2 == 0
119
+ self.do_mel_descaling = kwargs.get("do_mel_descaling", True)
120
+ self.n_mel_channels = n_mel_channels
121
+ self.n_f0_dims = n_f0_dims  # >= 1 to train with f0
122
+ self.n_energy_avg_dims = n_energy_avg_dims  # >= 1 to train with energy
123
+ self.decoder_use_partial_padding = kwargs.get(
124
+ "decoder_use_partial_padding", True
125
+ )
126
+ self.n_speaker_dim = n_speaker_dim
127
+ assert self.n_speaker_dim % 2 == 0
128
+ self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim)
129
+ self.embedding = torch.nn.Embedding(n_text, n_text_dim)
130
+ self.flows = torch.nn.ModuleList()
131
+ self.encoder = Encoder(
132
+ encoder_embedding_dim=n_text_dim,
133
+ norm_fn=nn.InstanceNorm1d,
134
+ lstm_norm_fn=text_encoder_lstm_norm,
135
+ )
136
+ self.dummy_speaker_embedding = dummy_speaker_embedding
137
+ self.learn_alignments = learn_alignments
138
+ self.affine_activation = affine_activation
139
+ self.include_modules = include_modules
140
+ self.attn_use_CTC = bool(attn_use_CTC)
141
+ self.use_speaker_emb_for_alignment = use_speaker_emb_for_alignment
142
+ self.use_context_lstm = bool(use_context_lstm)
143
+ self.context_lstm_norm = context_lstm_norm
144
+ self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy
145
+ self.length_regulator = LengthRegulator()
146
+ self.use_first_order_features = bool(use_first_order_features)
147
+ self.decoder_use_unvoiced_bias = kwargs.get("decoder_use_unvoiced_bias", True)
148
+ self.ap_pred_log_f0 = ap_pred_log_f0
149
+ self.ap_use_unvoiced_bias = kwargs.get("ap_use_unvoiced_bias", True)
150
+ self.attn_straight_through_estimator = kwargs.get(
151
+ "attn_straight_through_estimator", False
152
+ )
153
+ if "atn" in include_modules or "dec" in include_modules:
154
+ if self.learn_alignments:
155
+ if self.use_speaker_emb_for_alignment:
156
+ self.attention = ConvAttention(
157
+ n_mel_channels, n_text_dim + self.n_speaker_dim
158
+ )
159
+ else:
160
+ self.attention = ConvAttention(n_mel_channels, n_text_dim)
161
+
162
+ self.n_flows = n_flows
163
+ self.n_group_size = n_group_size
164
+
165
+ n_flowstep_cond_dims = (
166
+ self.n_speaker_dim
167
+ + (n_text_dim + n_f0_dims + n_energy_avg_dims) * n_group_size
168
+ )
169
+
170
+ if self.use_context_lstm:
171
+ n_in_context_lstm = self.n_speaker_dim + n_text_dim * n_group_size
172
+ n_context_lstm_hidden = int(
173
+ (self.n_speaker_dim + n_text_dim * n_group_size) / 2
174
+ )
175
+
176
+ if self.context_lstm_w_f0_and_energy:
177
+ n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim
178
+ n_in_context_lstm *= n_group_size
179
+ n_in_context_lstm += self.n_speaker_dim
180
+
181
+ n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim
182
+ n_context_hidden = n_context_hidden * n_group_size / 2
183
+ n_context_hidden = self.n_speaker_dim + n_context_hidden
184
+ n_context_hidden = int(n_context_hidden)
185
+
186
+ n_flowstep_cond_dims = (
187
+ self.n_speaker_dim + n_text_dim * n_group_size
188
+ )
189
+
190
+ self.context_lstm = torch.nn.LSTM(
191
+ input_size=n_in_context_lstm,
192
+ hidden_size=n_context_lstm_hidden,
193
+ num_layers=1,
194
+ batch_first=True,
195
+ bidirectional=True,
196
+ )
197
+
198
+ if context_lstm_norm is not None:
199
+ if "spectral" in context_lstm_norm:
200
+ print("Applying spectral norm to context encoder LSTM")
201
+ lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
202
+ elif "weight" in context_lstm_norm:
203
+ print("Applying weight norm to context encoder LSTM")
204
+ lstm_norm_fn_pntr = torch.nn.utils.weight_norm
205
+
206
+ self.context_lstm = lstm_norm_fn_pntr(
207
+ self.context_lstm, "weight_hh_l0"
208
+ )
209
+ self.context_lstm = lstm_norm_fn_pntr(
210
+ self.context_lstm, "weight_hh_l0_reverse"
211
+ )
212
+
213
+ if self.n_group_size > 1:
214
+ self.unfold_params = {
215
+ "kernel_size": (n_group_size, 1),
216
+ "stride": n_group_size,
217
+ "padding": 0,
218
+ "dilation": 1,
219
+ }
220
+ self.unfold = nn.Unfold(**self.unfold_params)
221
+
222
+ self.exit_steps = []
223
+ self.n_early_size = n_early_size
224
+ n_mel_channels = n_mel_channels * n_group_size
225
+
226
+ for i in range(self.n_flows):
227
+ if i > 0 and i % n_early_every == 0:  # early exiting
228
+ n_mel_channels -= self.n_early_size
229
+ self.exit_steps.append(i)
230
+
231
+ self.flows.append(
232
+ FlowStep(
233
+ n_mel_channels,
234
+ n_flowstep_cond_dims,
235
+ n_conv_layers_per_step,
236
+ affine_model,
237
+ scaling_fn,
238
+ matrix_decomposition,
239
+ affine_activation=affine_activation,
240
+ use_partial_padding=self.decoder_use_partial_padding,
241
+ )
242
+ )
243
+
244
+ if "dpm" in include_modules:
245
+ dur_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
246
+ self.dur_pred_layer = get_attribute_prediction_model(dur_model_config)
247
+
248
+ self.use_unvoiced_bias = False
249
+ self.use_vpred_module = False
250
+ self.ap_use_voiced_embeddings = kwargs.get("ap_use_voiced_embeddings", True)
251
+
252
+ if self.decoder_use_unvoiced_bias or self.ap_use_unvoiced_bias:
253
+ assert unvoiced_bias_activation in {"relu", "exp"}
254
+ self.use_unvoiced_bias = True
255
+ if unvoiced_bias_activation == "relu":
256
+ unvbias_nonlin = nn.ReLU()
257
+ elif unvoiced_bias_activation == "exp":
258
+ unvbias_nonlin = ExponentialClass()
259
+ else:
260
+ exit(1) # we won't reach here anyway due to the assertion
261
+ self.unvoiced_bias_module = nn.Sequential(
262
+ LinearNorm(n_text_dim, 1), unvbias_nonlin
263
+ )
264
+
265
+ # all situations in which the vpred module is necessary
266
+ if (
267
+ self.ap_use_voiced_embeddings
268
+ or self.use_unvoiced_bias
269
+ or "vpred" in include_modules
270
+ ):
271
+ self.use_vpred_module = True
272
+
273
+ if self.use_vpred_module:
274
+ v_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
275
+ self.v_pred_module = get_attribute_prediction_model(v_model_config)
276
+ # 4 embeddings, first two are scales, second two are biases
277
+ if self.ap_use_voiced_embeddings:
278
+ self.v_embeddings = torch.nn.Embedding(4, n_text_dim)
279
+
280
+ if "apm" in include_modules:
281
+ f0_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
282
+ energy_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
283
+ if self.use_first_order_features:
284
+ f0_model_config["hparams"]["n_in_dim"] = 2
285
+ energy_model_config["hparams"]["n_in_dim"] = 2
286
+ if (
287
+ "spline_flow_params" in f0_model_config["hparams"]
288
+ and f0_model_config["hparams"]["spline_flow_params"] is not None
289
+ ):
290
+ f0_model_config["hparams"]["spline_flow_params"][
291
+ "n_in_channels"
292
+ ] = 2
293
+ if (
294
+ "spline_flow_params" in energy_model_config["hparams"]
295
+ and energy_model_config["hparams"]["spline_flow_params"] is not None
296
+ ):
297
+ energy_model_config["hparams"]["spline_flow_params"][
298
+ "n_in_channels"
299
+ ] = 2
300
+ else:
301
+ if (
302
+ "spline_flow_params" in f0_model_config["hparams"]
303
+ and f0_model_config["hparams"]["spline_flow_params"] is not None
304
+ ):
305
+ f0_model_config["hparams"]["spline_flow_params"][
306
+ "n_in_channels"
307
+ ] = f0_model_config["hparams"]["n_in_dim"]
308
+ if (
309
+ "spline_flow_params" in energy_model_config["hparams"]
310
+ and energy_model_config["hparams"]["spline_flow_params"] is not None
311
+ ):
312
+ energy_model_config["hparams"]["spline_flow_params"][
313
+ "n_in_channels"
314
+ ] = energy_model_config["hparams"]["n_in_dim"]
315
+
316
+ self.f0_pred_module = get_attribute_prediction_model(f0_model_config)
317
+ self.energy_pred_module = get_attribute_prediction_model(
318
+ energy_model_config
319
+ )
320
+
321
+ def is_attribute_unconditional(self):
322
+ """
323
+ returns true if the decoder is conditioned on neither energy nor F0
324
+ """
325
+ return self.n_f0_dims == 0 and self.n_energy_avg_dims == 0
326
+
327
+ def encode_speaker(self, spk_ids):
328
+ spk_ids = spk_ids * 0 if self.dummy_speaker_embedding else spk_ids
329
+ spk_vecs = self.speaker_embedding(spk_ids)
330
+ return spk_vecs
331
+
332
+ def encode_text(self, text, in_lens):
333
+ # text_embeddings: b x len_text x n_text_dim
334
+ text_embeddings = self.embedding(text).transpose(1, 2)
335
+ # text_enc: b x n_text_dim x encoder_dim (512)
336
+ if in_lens is None:
337
+ text_enc = self.encoder.infer(text_embeddings).transpose(1, 2)
338
+ else:
339
+ text_enc = self.encoder(text_embeddings, in_lens).transpose(1, 2)
340
+
341
+ return text_enc, text_embeddings
342
+
343
+ def preprocess_context(
344
+ self, context, speaker_vecs, out_lens=None, f0=None, energy_avg=None
345
+ ):
346
+ if self.n_group_size > 1:
347
+ # unfolding zero-padded values
348
+ context = self.unfold(context.unsqueeze(-1))
349
+ if f0 is not None:
350
+ f0 = self.unfold(f0[:, None, :, None])
351
+ if energy_avg is not None:
352
+ energy_avg = self.unfold(energy_avg[:, None, :, None])
353
+ speaker_vecs = speaker_vecs[..., None].expand(-1, -1, context.shape[2])
354
+ context_w_spkvec = torch.cat((context, speaker_vecs), 1)
355
+
356
+ if self.use_context_lstm:
357
+ if self.context_lstm_w_f0_and_energy:
358
+ if f0 is not None:
359
+ context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
360
+
361
+ if energy_avg is not None:
362
+ context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
363
+
364
+ unfolded_out_lens = (out_lens // self.n_group_size).long().cpu()
365
+ unfolded_out_lens_packed = nn.utils.rnn.pack_padded_sequence(
366
+ context_w_spkvec.transpose(1, 2),
367
+ unfolded_out_lens,
368
+ batch_first=True,
369
+ enforce_sorted=False,
370
+ )
371
+ self.context_lstm.flatten_parameters()
372
+ context_lstm_packed_output, _ = self.context_lstm(unfolded_out_lens_packed)
373
+ context_lstm_padded_output, _ = nn.utils.rnn.pad_packed_sequence(
374
+ context_lstm_packed_output, batch_first=True
375
+ )
376
+ context_w_spkvec = context_lstm_padded_output.transpose(1, 2)
377
+
378
+ if not self.context_lstm_w_f0_and_energy:
379
+ if f0 is not None:
380
+ context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
381
+
382
+ if energy_avg is not None:
383
+ context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
384
+
385
+ return context_w_spkvec
386
+
387
+ def enable_inverse_cache(self):
388
+ for flow_step in self.flows:
389
+ flow_step.enable_inverse_cache()
390
+
391
+ def fold(self, mel):
392
+ """Inverse of the self.unfold(mel.unsqueeze(-1)) operation used for the
393
+ grouping or "squeeze" operation on input
394
+
395
+ Args:
396
+ mel: B x C x T tensor of temporal data
397
+ """
398
+ mel = nn.functional.fold(
399
+ mel, output_size=(mel.shape[2] * self.n_group_size, 1), **self.unfold_params
400
+ ).squeeze(-1)
401
+ return mel
402
+
403
+ def binarize_attention(self, attn, in_lens, out_lens):
404
+ """For training purposes only. Binarizes attention with MAS. These will
405
+ no longer receive a gradient
406
+ Args:
407
+ attn: B x 1 x max_mel_len x max_text_len
408
+ """
409
+ b_size = attn.shape[0]
410
+ with torch.no_grad():
411
+ attn_cpu = attn.data.cpu().numpy()
412
+ attn_out = torch.zeros_like(attn)
413
+ for ind in range(b_size):
414
+ hard_attn = mas(attn_cpu[ind, 0, : out_lens[ind], : in_lens[ind]])
415
+ attn_out[ind, 0, : out_lens[ind], : in_lens[ind]] = torch.tensor(
416
+ hard_attn, device=attn.get_device()
417
+ )
418
+ return attn_out
419
+
420
+ def get_first_order_features(self, feats, out_lens, dilation=1):
421
+ """
422
+ feats: b x max_length
423
+ out_lens: b-dim
424
+ """
425
+ # add an extra column
426
+ feats_extended_R = torch.cat(
427
+ (feats, torch.zeros_like(feats[:, 0:dilation])), dim=1
428
+ )
429
+ feats_extended_L = torch.cat(
430
+ (torch.zeros_like(feats[:, 0:dilation]), feats), dim=1
431
+ )
432
+ dfeats_R = feats_extended_R[:, dilation:] - feats
433
+ dfeats_L = feats - feats_extended_L[:, 0:-dilation]
434
+
435
+ return (dfeats_R + dfeats_L) * 0.5
436
+
437
+ def apply_voice_mask_to_text(self, text_enc, voiced_mask):
438
+ """
439
+ text_enc: b x C x N
440
+ voiced_mask: b x N
441
+ """
442
+ voiced_mask = voiced_mask.unsqueeze(1)
443
+ voiced_embedding_s = self.v_embeddings.weight[0:1, :, None]
444
+ unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None]
445
+ voiced_embedding_b = self.v_embeddings.weight[2:3, :, None]
446
+ unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None]
447
+ scale = torch.sigmoid(
448
+ voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask)
449
+ )
450
+ bias = 0.1 * torch.tanh(
451
+ voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask)
452
+ )
453
+ return text_enc * scale + bias
454
+
455
+ def forward(
456
+ self,
457
+ mel,
458
+ speaker_ids,
459
+ text,
460
+ in_lens,
461
+ out_lens,
462
+ binarize_attention=False,
463
+ attn_prior=None,
464
+ f0=None,
465
+ energy_avg=None,
466
+ voiced_mask=None,
467
+ p_voiced=None,
468
+ ):
469
+ speaker_vecs = self.encode_speaker(speaker_ids)
470
+ text_enc, text_embeddings = self.encode_text(text, in_lens)
471
+
472
+ log_s_list, log_det_W_list, z_mel = [], [], []
473
+ attn = None
474
+ attn_soft = None
475
+ attn_hard = None
476
+ if "atn" in self.include_modules or "dec" in self.include_modules:
477
+ # make sure to do the alignments before folding
478
+ attn_mask = get_mask_from_lengths(in_lens)[..., None] == 0
479
+
480
+ text_embeddings_for_attn = text_embeddings
481
+ if self.use_speaker_emb_for_alignment:
482
+ speaker_vecs_expd = speaker_vecs[:, :, None].expand(
483
+ -1, -1, text_embeddings.shape[2]
484
+ )
485
+ text_embeddings_for_attn = torch.cat(
486
+ (text_embeddings_for_attn, speaker_vecs_expd.detach()), 1
487
+ )
488
+
489
+ # attn_mask should be 1 for unused t-steps in the text_enc_w_spkvec tensor
490
+ attn_soft, attn_logprob = self.attention(
491
+ mel,
492
+ text_embeddings_for_attn,
493
+ out_lens,
494
+ attn_mask,
495
+ key_lens=in_lens,
496
+ attn_prior=attn_prior,
497
+ )
498
+
499
+ if binarize_attention:
500
+ attn = self.binarize_attention(attn_soft, in_lens, out_lens)
501
+ attn_hard = attn
502
+ if self.attn_straight_through_estimator:
503
+ attn_hard = attn_soft + (attn_hard - attn_soft).detach()
504
+ else:
505
+ attn = attn_soft
506
+
507
+ context = torch.bmm(text_enc, attn.squeeze(1).transpose(1, 2))
508
+
509
+ f0_bias = 0
510
+ # unvoiced bias forward pass
511
+ if self.use_unvoiced_bias:
512
+ f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1))
513
+ f0_bias = -f0_bias[..., 0]
514
+ f0_bias = f0_bias * (~voiced_mask.bool()).float()
515
+
516
+ # mel decoder forward pass
517
+ if "dec" in self.include_modules:
518
+ if self.n_group_size > 1:
519
+ # might truncate some frames at the end, but that's ok
520
+ # sometimes referred to as the "squeeze" operation
521
+ # invert this by calling self.fold(mel_or_z)
522
+ mel = self.unfold(mel.unsqueeze(-1))
523
+ z_out = []
524
+ # where context is folded
525
+ # mask f0 in case values are interpolated
526
+
527
+ if f0 is None:
528
+ f0_aug = None
529
+ else:
530
+ if self.decoder_use_unvoiced_bias:
531
+ f0_aug = f0 * voiced_mask + f0_bias
532
+ else:
533
+ f0_aug = f0 * voiced_mask
534
+
535
+ context_w_spkvec = self.preprocess_context(
536
+ context, speaker_vecs, out_lens, f0_aug, energy_avg
537
+ )
538
+
539
+ log_s_list, log_det_W_list, z_out = [], [], []
540
+ unfolded_seq_lens = out_lens // self.n_group_size
541
+ for i, flow_step in enumerate(self.flows):
542
+ if i in self.exit_steps:
543
+ z = mel[:, : self.n_early_size]
544
+ z_out.append(z)
545
+ mel = mel[:, self.n_early_size :]
546
+ mel, log_det_W, log_s = flow_step(
547
+ mel, context_w_spkvec, seq_lens=unfolded_seq_lens
548
+ )
549
+ log_s_list.append(log_s)
550
+ log_det_W_list.append(log_det_W)
551
+
552
+ z_out.append(mel)
553
+ z_mel = torch.cat(z_out, 1)
554
+
555
+ # duration predictor forward pass
556
+ duration_model_outputs = None
557
+ if "dpm" in self.include_modules:
558
+ if attn_hard is None:
559
+ attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
560
+
561
+ # convert hard attention to durations
562
+ attn_hard_reduced = attn_hard.sum(2)[:, 0, :]
563
+ duration_model_outputs = self.dur_pred_layer(
564
+ torch.detach(text_enc),
565
+ torch.detach(speaker_vecs),
566
+ torch.detach(attn_hard_reduced.float()),
567
+ in_lens,
568
+ )
569
+
570
+ # f0, energy, vpred predictors forward pass
571
+ f0_model_outputs = None
572
+ energy_model_outputs = None
573
+ vpred_model_outputs = None
574
+ if "apm" in self.include_modules:
575
+ if attn_hard is None:
576
+ attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
577
+
578
+ # convert hard attention to durations
579
+ if binarize_attention:
580
+ text_enc_time_expanded = context.clone()
581
+ else:
582
+ text_enc_time_expanded = torch.bmm(
583
+ text_enc, attn_hard.squeeze(1).transpose(1, 2)
584
+ )
585
+
586
+ if self.use_vpred_module:
587
+ # unvoiced bias requires voiced mask prediction
588
+ vpred_model_outputs = self.v_pred_module(
589
+ torch.detach(text_enc_time_expanded),
590
+ torch.detach(speaker_vecs),
591
+ torch.detach(voiced_mask),
592
+ out_lens,
593
+ )
594
+
595
+ # affine transform context using voiced mask
596
+ if self.ap_use_voiced_embeddings:
597
+ text_enc_time_expanded = self.apply_voice_mask_to_text(
598
+ text_enc_time_expanded, voiced_mask
599
+ )
600
+
601
+ # whether to use the unvoiced bias in the attribute predictor
602
+ # circumvent in-place modification
603
+ f0_target = f0.clone()
604
+ if self.ap_use_unvoiced_bias:
605
+ f0_target = torch.detach(f0_target * voiced_mask + f0_bias)
606
+ else:
607
+ f0_target = torch.detach(f0_target)
608
+
609
+ # fit to log f0 in f0 predictor
610
+ f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()])
611
+ f0_target = f0_target / 6 # scale to ~ [0, 1] in log space
612
+ energy_avg = energy_avg * 2 - 1 # scale to ~ [-1, 1]
613
+
614
+ if self.use_first_order_features:
615
+ df0 = self.get_first_order_features(f0_target, out_lens)
616
+ denergy_avg = self.get_first_order_features(energy_avg, out_lens)
617
+
618
+ f0_voiced = torch.cat((f0_target[:, None], df0[:, None]), dim=1)
619
+ energy_avg = torch.cat(
620
+ (energy_avg[:, None], denergy_avg[:, None]), dim=1
621
+ )
622
+
623
+ f0_voiced = f0_voiced * 3 # scale to ~ 1 std
624
+ energy_avg = energy_avg * 3 # scale to ~ 1 std
625
+ else:
626
+ f0_voiced = f0_target * 2 # scale to ~ 1 std
627
+ energy_avg = energy_avg * 1.4 # scale to ~ 1 std
628
+
629
+ f0_model_outputs = self.f0_pred_module(
630
+ text_enc_time_expanded, torch.detach(speaker_vecs), f0_voiced, out_lens
631
+ )
632
+
633
+ energy_model_outputs = self.energy_pred_module(
634
+ text_enc_time_expanded, torch.detach(speaker_vecs), energy_avg, out_lens
635
+ )
636
+
637
+ outputs = {
638
+ "z_mel": z_mel,
639
+ "log_det_W_list": log_det_W_list,
640
+ "log_s_list": log_s_list,
641
+ "duration_model_outputs": duration_model_outputs,
642
+ "f0_model_outputs": f0_model_outputs,
643
+ "energy_model_outputs": energy_model_outputs,
644
+ "vpred_model_outputs": vpred_model_outputs,
645
+ "attn_soft": attn_soft,
646
+ "attn": attn,
647
+ "text_embeddings": text_embeddings,
648
+ "attn_logprob": attn_logprob,
649
+ }
650
+
651
+ return outputs
652
+
653
+ def infer(
654
+ self,
655
+ speaker_id,
656
+ text,
657
+ sigma,
658
+ sigma_dur=0.8,
659
+ sigma_f0=0.8,
660
+ sigma_energy=0.8,
661
+ token_dur_scaling=1.0,
662
+ token_duration_max=100,
663
+ speaker_id_text=None,
664
+ speaker_id_attributes=None,
665
+ dur=None,
666
+ f0=None,
667
+ energy_avg=None,
668
+ voiced_mask=None,
669
+ f0_mean=0.0,
670
+ f0_std=0.0,
671
+ energy_mean=0.0,
672
+ energy_std=0.0,
673
+ use_cuda=False,
674
+ ):
675
+ batch_size = text.shape[0]
676
+ n_tokens = text.shape[1]
677
+ spk_vec = self.encode_speaker(speaker_id)
678
+ spk_vec_text, spk_vec_attributes = spk_vec, spk_vec
679
+ if speaker_id_text is not None:
680
+ spk_vec_text = self.encode_speaker(speaker_id_text)
681
+ if speaker_id_attributes is not None:
682
+ spk_vec_attributes = self.encode_speaker(speaker_id_attributes)
683
+
684
+ txt_enc, txt_emb = self.encode_text(text, None)
685
+
686
+ if dur is None:
687
+ # get token durations
688
+ if use_cuda:
689
+ z_dur = torch.cuda.FloatTensor(batch_size, 1, n_tokens)
690
+ else:
691
+ z_dur = torch.FloatTensor(batch_size, 1, n_tokens)
692
+
693
+ z_dur = z_dur.normal_() * sigma_dur
694
+
695
+ dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
696
+ if dur.shape[-1] < txt_enc.shape[-1]:
697
+ to_pad = txt_enc.shape[-1] - dur.shape[2]
698
+ pad_fn = nn.ReplicationPad1d((0, to_pad))
699
+ dur = pad_fn(dur)
700
+ dur = dur[:, 0]
701
+ dur = dur.clamp(0, token_duration_max)
702
+ dur = dur * token_dur_scaling if token_dur_scaling > 0 else dur
703
+ dur = (dur + 0.5).floor().int()
704
+
705
+ out_lens = dur.sum(1).long().cpu() if dur.shape[0] != 1 else [dur.sum(1)]
706
+ max_n_frames = max(out_lens)
707
+
708
+ out_lens = torch.LongTensor(out_lens).to(txt_enc.device)
709
+
710
+ # get attributes (f0, energy, vpred, etc.)
711
+ txt_enc_time_expanded = self.length_regulator(
712
+ txt_enc.transpose(1, 2), dur
713
+ ).transpose(1, 2)
714
+
715
+ if not self.is_attribute_unconditional():
716
+ # if explicitly modeling attributes
717
+ if voiced_mask is None:
718
+ if self.use_vpred_module:
719
+ # get logits
720
+ voiced_mask = self.v_pred_module.infer(
721
+ None, txt_enc_time_expanded, spk_vec_attributes
722
+ )
723
+ voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5
724
+ voiced_mask = voiced_mask.float()
725
+
726
+ ap_txt_enc_time_expanded = txt_enc_time_expanded
727
+ # voice mask augmentation only used for attribute prediction
728
+ if self.ap_use_voiced_embeddings:
729
+ ap_txt_enc_time_expanded = self.apply_voice_mask_to_text(
730
+ txt_enc_time_expanded, voiced_mask
731
+ )
732
+
733
+ f0_bias = 0
734
+ # unvoiced bias forward pass
735
+ if self.use_unvoiced_bias:
736
+ f0_bias = self.unvoiced_bias_module(
737
+ txt_enc_time_expanded.permute(0, 2, 1)
738
+ )
739
+ f0_bias = -f0_bias[..., 0]
740
+ f0_bias = f0_bias * (~voiced_mask.bool()).float()
741
+
742
+ if f0 is None:
743
+ n_f0_feature_channels = 2 if self.use_first_order_features else 1
744
+
745
+ if use_cuda:
746
+ z_f0 = (
747
+ torch.cuda.FloatTensor(
748
+ batch_size, n_f0_feature_channels, max_n_frames
749
+ ).normal_()
750
+ * sigma_f0
751
+ )
752
+ else:
753
+ z_f0 = (
754
+ torch.FloatTensor(
755
+ batch_size, n_f0_feature_channels, max_n_frames
756
+ ).normal_()
757
+ * sigma_f0
758
+ )
759
+
760
+ f0 = self.infer_f0(
761
+ z_f0,
762
+ ap_txt_enc_time_expanded,
763
+ spk_vec_attributes,
764
+ voiced_mask,
765
+ out_lens,
766
+ )[:, 0]
767
+
768
+ if f0_mean > 0.0:
769
+ vmask_bool = voiced_mask.bool()
770
+ f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std()
771
+ f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma
772
+ f0_std = f0_std if f0_std > 0 else f0_sigma
773
+ f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean
774
+
775
+ if energy_avg is None:
776
+ n_energy_feature_channels = 2 if self.use_first_order_features else 1
777
+ if use_cuda:
778
+ z_energy_avg = (
779
+ torch.cuda.FloatTensor(
780
+ batch_size, n_energy_feature_channels, max_n_frames
781
+ ).normal_()
782
+ * sigma_energy
783
+ )
784
+ else:
785
+ z_energy_avg = (
786
+ torch.FloatTensor(
787
+ batch_size, n_energy_feature_channels, max_n_frames
788
+ ).normal_()
789
+ * sigma_energy
790
+ )
791
+ energy_avg = self.infer_energy(
792
+ z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
793
+ )[:, 0]
794
+
795
+ # replication pad, because ungrouping with different group sizes
796
+ # may lead to mismatched lengths
797
+ if energy_avg.shape[1] < out_lens[0]:
798
+ to_pad = out_lens[0] - energy_avg.shape[1]
799
+ pad_fn = nn.ReplicationPad1d((0, to_pad))
800
+ f0 = pad_fn(f0[None])[0]
801
+ energy_avg = pad_fn(energy_avg[None])[0]
802
+ if f0.shape[1] < out_lens[0]:
803
+ to_pad = out_lens[0] - f0.shape[1]
804
+ pad_fn = nn.ReplicationPad1d((0, to_pad))
805
+ f0 = pad_fn(f0[None])[0]
806
+
807
+ if self.decoder_use_unvoiced_bias:
808
+ context_w_spkvec = self.preprocess_context(
809
+ txt_enc_time_expanded,
810
+ spk_vec,
811
+ out_lens,
812
+ f0 * voiced_mask + f0_bias,
813
+ energy_avg,
814
+ )
815
+ else:
816
+ context_w_spkvec = self.preprocess_context(
817
+ txt_enc_time_expanded,
818
+ spk_vec,
819
+ out_lens,
820
+ f0 * voiced_mask,
821
+ energy_avg,
822
+ )
823
+ else:
824
+ context_w_spkvec = self.preprocess_context(
825
+ txt_enc_time_expanded, spk_vec, out_lens, None, None
826
+ )
827
+
828
+ if use_cuda:
829
+ residual = torch.cuda.FloatTensor(
830
+ batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
831
+ )
832
+ else:
833
+ residual = torch.FloatTensor(
834
+ batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
835
+ )
836
+
837
+ residual = residual.normal_() * sigma
838
+
839
+ # map from z sample to data
840
+ exit_steps_stack = self.exit_steps.copy()
841
+ mel = residual[:, len(exit_steps_stack) * self.n_early_size :]
842
+ remaining_residual = residual[:, : len(exit_steps_stack) * self.n_early_size]
843
+ unfolded_seq_lens = out_lens // self.n_group_size
844
+ for i, flow_step in enumerate(reversed(self.flows)):
845
+ curr_step = len(self.flows) - i - 1
846
+ mel = flow_step(
847
+ mel, context_w_spkvec, inverse=True, seq_lens=unfolded_seq_lens
848
+ )
849
+ if len(exit_steps_stack) > 0 and curr_step == exit_steps_stack[-1]:
850
+ # concatenate the next chunk of z
851
+ exit_steps_stack.pop()
852
+ residual_to_add = remaining_residual[
853
+ :, len(exit_steps_stack) * self.n_early_size :
854
+ ]
855
+ remaining_residual = remaining_residual[
856
+ :, : len(exit_steps_stack) * self.n_early_size
857
+ ]
858
+ mel = torch.cat((residual_to_add, mel), 1)
859
+
860
+ if self.n_group_size > 1:
861
+ mel = self.fold(mel)
862
+ if self.do_mel_descaling:
863
+ mel = mel * 2 - 5.5
864
+
865
+ return {
866
+ "mel": mel,
867
+ "dur": dur,
868
+ "f0": f0,
869
+ "energy_avg": energy_avg,
870
+ "voiced_mask": voiced_mask,
871
+ }
872
+
873
+ def infer_f0(
874
+ self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None
875
+ ):
876
+ f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens)
877
+
878
+ if voiced_mask is not None and len(voiced_mask.shape) == 2:
879
+ voiced_mask = voiced_mask[:, None]
880
+
881
+ # constants
882
+ if self.ap_pred_log_f0:
883
+ if self.use_first_order_features:
884
+ f0 = f0[:, 0:1, :] / 3
885
+ else:
886
+ f0 = f0 / 2
887
+ f0 = f0 * 6
888
+ else:
889
+ f0 = f0 / 6
890
+ f0 = f0 / 640
891
+
892
+ if voiced_mask is None:
893
+ voiced_mask = f0 > 0.0
894
+ else:
895
+ voiced_mask = voiced_mask.bool()
896
+
897
+ # due to grouping, f0 might be 1 frame short
898
+ voiced_mask = voiced_mask[:, :, : f0.shape[-1]]
899
+ if self.ap_pred_log_f0:
900
+ # if variable is set, decoder sees linear f0
901
+ # mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool()
902
+ f0[voiced_mask] = torch.exp(f0[voiced_mask])
903
+ f0[~voiced_mask] = 0.0
904
+ return f0
905
+
906
+ def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens):
907
+ energy = self.energy_pred_module.infer(
908
+ residual, txt_enc_time_expanded, spk_vec, lens
909
+ )
910
+
911
+ # magic constants
912
+ if self.use_first_order_features:
913
+ energy = energy / 3
914
+ else:
915
+ energy = energy / 1.4
916
+ energy = (energy + 1) / 2
917
+ return energy
918
+
919
+ def remove_norms(self):
920
+ """Removes spectral and weightnorms from model. Call before inference"""
921
+ for name, module in self.named_modules():
922
+ try:
923
+ nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
924
+ print("Removed spectral norm from {}".format(name))
925
+ except:
926
+ pass
927
+ try:
928
+ nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
929
+ print("Removed spectral norm from {}".format(name))
930
+ except:
931
+ pass
932
+ try:
933
+ nn.utils.remove_weight_norm(module)
934
+ print("Removed wnorm from {}".format(name))
935
+ except:
936
+ pass
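For orientation, here is a minimal inference sketch against the model class added above. It assumes the module is saved as `radtts.py`, that the constructor arguments come from a config file shipped with the checkpoint, and that the checkpoint stores its weights under a `state_dict` key; the paths and key names are illustrative assumptions, not the exact layout used by this Space.

```python
import json

import torch

from radtts import RADTTS  # the module added in this commit (assumed filename)

# Hypothetical paths: the real Space downloads the config and weights from the Hub.
config = json.load(open("models/config.json"))
model = RADTTS(**config["model_config"])
state = torch.load("models/radtts.pt", map_location="cpu")
model.load_state_dict(state["state_dict"])  # checkpoint key name is an assumption

model.remove_norms()          # strip weight/spectral norms before inference
model.enable_inverse_cache()  # cache inverses of the invertible 1x1 convolutions
model.eval()

with torch.no_grad():
    text = torch.LongTensor([[12, 45, 7, 33]])  # token ids from the text processor
    speaker_id = torch.LongTensor([0])
    out = model.infer(speaker_id, text, sigma=0.8)
    mel = out["mel"]  # B x n_mel_channels x T, passed on to the Vocos vocoder
```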
requirements-dev.txt ADDED
@@ -0,0 +1 @@
1
+ ruff
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ huggingface_hub
2
+
3
+ gradio==5.18.0
4
+
5
+ torch
6
+ torchaudio
7
+ scipy
8
+ numba
9
+ lmdb
10
+ librosa
11
+
12
+ unidecode
13
+ inflect
14
+
15
+ git+https://github.com/langtech-bsc/vocos.git@matcha
splines.py ADDED
@@ -0,0 +1,326 @@
1
+ # Original Source:
2
+ # Original Source:
3
+ # https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
4
+ # https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
5
+ # Modifications made to jacobian computation by Yurong You and Kevin Shih
6
+ # Original License Text:
7
+ #########################################################################
8
+
9
+ # The MIT License (MIT)
10
+ # Copyright (c) 2020, nicolas deutschmann
11
+
12
+ # Permission is hereby granted, free of charge, to any person obtaining
13
+ # a copy of this software and associated documentation files (the
14
+ # "Software"), to deal in the Software without restriction, including
15
+ # without limitation the rights to use, copy, modify, merge, publish,
16
+ # distribute, sublicense, and/or sell copies of the Software, and to
17
+ # permit persons to whom the Software is furnished to do so, subject to
18
+ # the following conditions:
19
+
20
+ # The above copyright notice and this permission notice shall be
21
+ # included in all copies or substantial portions of the Software.
22
+
23
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26
+ # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
27
+ # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
28
+ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
29
+ # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
30
+
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+
35
+ third_dimension_softmax = torch.nn.Softmax(dim=2)
36
+
37
+
38
+ def piecewise_linear_transform(
39
+ x, q_tilde, compute_jacobian=True, outlier_passthru=True
40
+ ):
41
+ """Apply an element-wise piecewise-linear transformation to some variables
42
+
43
+ Parameters
44
+ ----------
45
+ x : torch.Tensor
46
+ a tensor with shape (N,k) where N is the batch dimension while k is the
47
+ dimension of the variable space. This variable spans the k-dimensional unit
48
+ hypercube
49
+
50
+ q_tilde: torch.Tensor
51
+ is a tensor with shape (N,k,b) where b is the number of bins.
52
+ This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
53
+ i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
54
+ Normalization is imposed in this function using softmax.
55
+
56
+ compute_jacobian : bool, optional
57
+ determines whether the jacobian should be computed or None is returned
58
+
59
+ Returns
60
+ -------
61
+ tuple of torch.Tensor
62
+ pair `(y, j)`.
63
+ - `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
64
+ - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
65
+ """
66
+ logj = None
67
+
68
+ # TODO bottom-up assessment of handling the differentiability of variables
69
+ # Compute the bin width w
70
+ N, k, b = q_tilde.shape
71
+ Nx, kx = x.shape
72
+ assert N == Nx and k == kx, "Shape mismatch"
73
+
74
+ w = 1.0 / b
75
+
76
+ # Compute normalized bin heights with softmax function on bin dimension
77
+ q = 1.0 / w * third_dimension_softmax(q_tilde)
78
+ # x is in the mx-th bin: x \in [0,1],
79
+ # mx \in [[0,b-1]], so we clamp away the case x == 1
80
+ mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
81
+ # Need special error handling because trying to index with mx
82
+ # if it contains nans will lock the GPU. (device-side assert triggered)
83
+ if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
84
+ raise Exception("NaN detected in PWLinear bin indexing")
85
+
86
+ # We compute the output variable in-place
87
+ out = x - mx * w  # alpha (element of [0., w]), the position of x in its bin
88
+
89
+ # Multiply by the slope
90
+ # q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k) with entries that are a b-index
91
+ # gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value
92
+ # i.e. we say slope[i, j] = q[i, j, mx [i, j]]
93
+ slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
94
+ out = out * slopes
95
+ # The jacobian is the product of the slopes in all dimensions
96
+
97
+ # Compute the integral over the left-bins.
98
+ # 1. Compute all integrals: cumulative sum of bin height * bin weight.
99
+ # We want that index i contains the cumsum *strictly to the left* so we shift by 1
100
+ # leaving the first entry null, which is achieved with a roll and assignment
101
+ q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
102
+ q_left_integrals[:, :, 0] = 0
103
+
104
+ # 2. Access the correct index to get the left integral of each point and add it to our transformation
105
+ out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)
106
+
107
+ # Regularization: points must be strictly within the unit hypercube
108
+ # Use the dtype information from pytorch
109
+ eps = torch.finfo(out.dtype).eps
110
+ out = out.clamp(min=eps, max=1.0 - eps)
111
+ oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
112
+ if outlier_passthru:
113
+ out = out * (1 - oob_mask) + x * oob_mask
114
+ slopes = slopes * (1 - oob_mask) + oob_mask
115
+
116
+ if compute_jacobian:
117
+ # logj = torch.log(torch.prod(slopes.float(), 1))
118
+ logj = torch.sum(torch.log(slopes), 1)
119
+ del slopes
120
+
121
+ return out, logj
122
+
123
+
124
+ def piecewise_linear_inverse_transform(
125
+ y, q_tilde, compute_jacobian=True, outlier_passthru=True
126
+ ):
127
+ """
128
+ Apply inverse of an element-wise piecewise-linear transformation to some
129
+ variables
130
+
131
+ Parameters
132
+ ----------
133
+ y : torch.Tensor
134
+ a tensor with shape (N,k) where N is the batch dimension while k is the
135
+ dimension of the variable space. This variable spans the k-dimensional unit
136
+ hypercube
137
+
138
+ q_tilde: torch.Tensor
139
+ is a tensor with shape (N,k,b) where b is the number of bins.
140
+ This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
141
+ i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
142
+ Normalization is imposed in this function using softmax.
143
+
144
+ compute_jacobian : bool, optional
145
+ determines whether the jacobian should be computed or None is returned
146
+
147
+ Returns
148
+ -------
149
+ tuple of torch.Tensor
150
+ pair `(x, j)`.
151
+ - `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
152
+ - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
153
+ """
154
+
155
+ # TODO bottom-up assessment of handling the differentiability of variables
156
+
157
+ # Compute the bin width w
158
+ N, k, b = q_tilde.shape
159
+ Ny, ky = y.shape
160
+ assert N == Ny and k == ky, "Shape mismatch"
161
+
162
+ w = 1.0 / b
163
+
164
+ # Compute normalized bin heights with softmax function on the bin dimension
165
+ q = 1.0 / w * third_dimension_softmax(q_tilde)
166
+
167
+ # Compute the integral over the left-bins in the forward transform.
168
+ # 1. Compute all integrals: cumulative sum of bin height * bin weight.
169
+ # We want that index i contains the cumsum *strictly to the left*,
170
+ # so we shift by 1 leaving the first entry null,
171
+ # which is achieved with a roll and assignment
172
+ q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
173
+ q_left_integrals[:, :, 0] = 0
174
+
175
+ # Find which bin each y belongs to by finding the smallest bin such that
176
+ # y - q_left_integral is positive
177
+
178
+ edges = (y.unsqueeze(-1) - q_left_integrals).detach()
179
+ # y and q_left_integrals are between 0 and 1,
180
+ # so that their difference is at most 1.
181
+ # By setting the negative values to 2., we know that the
182
+ # smallest value left is the smallest positive
183
+ edges[edges < 0] = 2.0
184
+ edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)
185
+
186
+ # Need special error handling because trying to index with mx
187
+ # if it contains nans will lock the GPU. (device-side assert triggered)
188
+ if (
189
+ torch.any(torch.isnan(edges)).item()
190
+ or torch.any(edges < 0)
191
+ or torch.any(edges >= b)
192
+ ):
193
+ raise Exception("NaN detected in PWLinear bin indexing")
194
+
195
+ # Gather the left integrals at each edge. See comment about gathering in q_left_integrals
196
+ # for the unsqueeze
197
+ q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)
198
+
199
+ # Gather the slope at each edge.
200
+ q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)
201
+
202
+ # Build the output
203
+ x = (y - q_left_integrals) / q + edges * w
204
+
205
+ # Regularization: points must be strictly within the unit hypercube
206
+ # Use the dtype information from pytorch
207
+ eps = torch.finfo(x.dtype).eps
208
+ x = x.clamp(min=eps, max=1.0 - eps)
209
+ oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
210
+ if outlier_passthru:
211
+ x = x * (1 - oob_mask) + y * oob_mask
212
+ q = q * (1 - oob_mask) + oob_mask
213
+
214
+ # Prepare the jacobian
215
+ logj = None
216
+ if compute_jacobian:
217
+ # logj = - torch.log(torch.prod(q, 1))
218
+ logj = -torch.sum(torch.log(q.float()), 1)
219
+ return x.detach(), logj
220
+
221
+
222
+ def unbounded_piecewise_quadratic_transform(
223
+ x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
224
+ ):
225
+ assert upper > lower
226
+ _range = upper - lower
227
+ inside_interval_mask = (x >= lower) & (x < upper)
228
+ outside_interval_mask = ~inside_interval_mask
229
+
230
+ outputs = torch.zeros_like(x)
231
+ log_j = torch.zeros_like(x)
232
+
233
+ outputs[outside_interval_mask] = x[outside_interval_mask]
234
+ log_j[outside_interval_mask] = 0
235
+
236
+ output, _log_j = piecewise_quadratic_transform(
237
+ (x[inside_interval_mask] - lower) / _range,
238
+ w_tilde[inside_interval_mask, :],
239
+ v_tilde[inside_interval_mask, :],
240
+ inverse=inverse,
241
+ )
242
+ outputs[inside_interval_mask] = output * _range + lower
243
+ if not inverse:
244
+ # the before and after transformation cancel out, so the log_j would be just as it is.
245
+ log_j[inside_interval_mask] = _log_j
246
+ else:
247
+ log_j = None
248
+ return outputs, log_j
249
+
250
+
251
+ def weighted_softmax(v, w):
252
+ # to avoid NaN...
253
+ v = v - torch.max(v, dim=-1, keepdim=True)[0]
254
+ v = torch.exp(v) + 1e-8 # to avoid NaN...
255
+ v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
256
+ return v / v_sum
257
+
258
+
259
+ def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
260
+ """Element-wise piecewise-quadratic transformation
261
+ Parameters
262
+ ----------
263
+ x : torch.Tensor
264
+ *, The variable spans the D-dim unit hypercube ([0,1))
265
+ w_tilde : torch.Tensor
266
+ * x K defined in the paper
267
+ v_tilde : torch.Tensor
268
+ * x (K+1) defined in the paper
269
+ inverse : bool
270
+ forward or inverse
271
+ Returns
272
+ -------
273
+ c : torch.Tensor
274
+ *, transformed value
275
+ log_j : torch.Tensor
276
+ *, log determinant of the Jacobian matrix
277
+ """
278
+ w = torch.softmax(w_tilde, dim=-1)
279
+ v = weighted_softmax(v_tilde, w)
280
+ w_cumsum = torch.cumsum(w, dim=-1)
281
+ # force sum = 1
282
+ w_cumsum[..., -1] = 1.0
283
+ w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
284
+ cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
285
+ # force sum = 1
286
+ cdf[..., -1] = 1.0
287
+ cdf_shift = F.pad(cdf, (1, 0), "constant", 0)
288
+
289
+ if not inverse:
290
+ # * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
291
+ bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
292
+ else:
293
+ # * x D x 1, (cdf[idx-1] < x <= cdf[idx])
294
+ bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))
295
+
296
+ w_b = torch.gather(w, -1, bin_index).squeeze(-1)
297
+ w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
298
+ v_b = torch.gather(v, -1, bin_index).squeeze(-1)
299
+ v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
300
+ cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)
301
+
302
+ if not inverse:
303
+ alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
304
+ c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1
305
+
306
+ # just sum of log pdfs
307
+ log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()
308
+
309
+ # make sure it falls into [0,1)
310
+ c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
311
+ return c, log_j
312
+ else:
313
+ # quadratic equation for alpha
314
+ # alpha should fall into (0, 1]. Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root
315
+ # skip calculating the log_j in inverse since we don't need it
316
+ a = (v_bp1 - v_b) * w_b / 2
317
+ b = v_b * w_b
318
+ c = cdf_bn1 - x
319
+ alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
320
+ inv = alpha * w_b + w_bn1
321
+
322
+ # make sure it falls into [0,1)
323
+ inv = inv.clamp(
324
+ min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
325
+ )
326
+ return inv, None
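A quick sanity check of the piecewise-linear pair defined above: the inverse transform should undo the forward transform (up to the epsilon clamping at the hypercube boundary), and the two log-Jacobians should cancel. The sketch assumes the file is importable as `splines` from the repository root.

```python
import torch

from splines import piecewise_linear_transform, piecewise_linear_inverse_transform

N, k, b = 4, 8, 16              # batch size, variable dimensions, bins
x = torch.rand(N, k)            # points inside the unit hypercube
q_tilde = torch.randn(N, k, b)  # un-normalized bin heights

y, logj = piecewise_linear_transform(x, q_tilde)
x_rec, logj_inv = piecewise_linear_inverse_transform(y, q_tilde)

print(torch.allclose(x, x_rec, atol=1e-4))         # expected: True (round trip recovers x)
print(torch.allclose(logj, -logj_inv, atol=1e-4))  # expected: True (log-dets cancel)
```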
transformer.py ADDED
@@ -0,0 +1,219 @@
1
+ # adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py
2
+ # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from common import get_mask_from_lengths, LinearNorm
19
+
20
+
21
+ class PositionalEmbedding(nn.Module):
22
+ def __init__(self, demb):
23
+ super(PositionalEmbedding, self).__init__()
24
+ self.demb = demb
25
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
26
+ self.register_buffer("inv_freq", inv_freq)
27
+
28
+ def forward(self, pos_seq, bsz=None):
29
+ sinusoid_inp = torch.matmul(
30
+ torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)
31
+ )
32
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
33
+ if bsz is not None:
34
+ return pos_emb[None, :, :].expand(bsz, -1, -1)
35
+ else:
36
+ return pos_emb[None, :, :]
37
+
38
+
39
+ class PositionwiseConvFF(nn.Module):
40
+ def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
41
+ super(PositionwiseConvFF, self).__init__()
42
+
43
+ self.d_model = d_model
44
+ self.d_inner = d_inner
45
+ self.dropout = dropout
46
+
47
+ self.CoreNet = nn.Sequential(
48
+ nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
49
+ nn.ReLU(),
50
+ # nn.Dropout(dropout), # worse convergence
51
+ nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
52
+ nn.Dropout(dropout),
53
+ )
54
+ self.layer_norm = nn.LayerNorm(d_model)
55
+ self.pre_lnorm = pre_lnorm
56
+
57
+ def forward(self, inp):
58
+ return self._forward(inp)
59
+
60
+ def _forward(self, inp):
61
+ if self.pre_lnorm:
62
+ # layer normalization + positionwise feed-forward
63
+ core_out = inp.transpose(1, 2)
64
+ core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
65
+ core_out = core_out.transpose(1, 2)
66
+
67
+ # residual connection
68
+ output = core_out + inp
69
+ else:
70
+ # positionwise feed-forward
71
+ core_out = inp.transpose(1, 2)
72
+ core_out = self.CoreNet(core_out)
73
+ core_out = core_out.transpose(1, 2)
74
+
75
+ # residual connection + layer normalization
76
+ output = self.layer_norm(inp + core_out).to(inp.dtype)
77
+
78
+ return output
79
+
80
+
81
+ class MultiHeadAttn(nn.Module):
82
+ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=False):
83
+ super(MultiHeadAttn, self).__init__()
84
+
85
+ self.n_head = n_head
86
+ self.d_model = d_model
87
+ self.d_head = d_head
88
+ self.scale = 1 / (d_head**0.5)
89
+ self.pre_lnorm = pre_lnorm
90
+
91
+ self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
92
+ self.drop = nn.Dropout(dropout)
93
+ self.dropatt = nn.Dropout(dropatt)
94
+ self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
95
+ self.layer_norm = nn.LayerNorm(d_model)
96
+
97
+ def forward(self, inp, attn_mask=None):
98
+ return self._forward(inp, attn_mask)
99
+
100
+ def _forward(self, inp, attn_mask=None):
101
+ residual = inp
102
+
103
+ if self.pre_lnorm:
104
+ # layer normalization
105
+ inp = self.layer_norm(inp)
106
+
107
+ n_head, d_head = self.n_head, self.d_head
108
+
109
+ head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
110
+ head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
111
+ head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
112
+ head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
113
+
114
+ q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
115
+ k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
116
+ v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
117
+
118
+ attn_score = torch.bmm(q, k.transpose(1, 2))
119
+ attn_score.mul_(self.scale)
120
+
121
+ if attn_mask is not None:
122
+ attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
123
+ attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
124
+ attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
125
+
126
+ attn_prob = F.softmax(attn_score, dim=2)
127
+ attn_prob = self.dropatt(attn_prob)
128
+ attn_vec = torch.bmm(attn_prob, v)
129
+
130
+ attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
131
+ attn_vec = (
132
+ attn_vec.permute(1, 2, 0, 3)
133
+ .contiguous()
134
+ .view(inp.size(0), inp.size(1), n_head * d_head)
135
+ )
136
+
137
+ # linear projection
138
+ attn_out = self.o_net(attn_vec)
139
+ attn_out = self.drop(attn_out)
140
+
141
+ # residual connection + layer normalization
142
+ output = self.layer_norm(residual + attn_out)
143
+
144
+ output = output.to(attn_out.dtype)
145
+
146
+ return output
147
+
148
+
149
+ class TransformerLayer(nn.Module):
150
+ def __init__(
151
+ self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs
152
+ ):
153
+ super(TransformerLayer, self).__init__()
154
+
155
+ self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
156
+ self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout)
157
+
158
+ def forward(self, dec_inp, mask=None):
159
+ output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
160
+ output *= mask
161
+ output = self.pos_ff(output)
162
+ output *= mask
163
+ return output
164
+
165
+
166
+ class FFTransformer(nn.Module):
167
+ def __init__(
168
+ self,
169
+ in_dim,
170
+ out_dim=1,
171
+ n_layers=6,
172
+ n_head=1,
173
+ d_head=64,
174
+ d_inner=1024,
175
+ kernel_size=3,
176
+ dropout=0.1,
177
+ dropatt=0.1,
178
+ dropemb=0.0,
179
+ ):
180
+ super(FFTransformer, self).__init__()
181
+ self.in_dim = in_dim
182
+ self.out_dim = out_dim
183
+ self.n_head = n_head
184
+ self.d_head = d_head
185
+
186
+ self.pos_emb = PositionalEmbedding(self.in_dim)
187
+ self.drop = nn.Dropout(dropemb)
188
+ self.layers = nn.ModuleList()
189
+
190
+ for _ in range(n_layers):
191
+ self.layers.append(
192
+ TransformerLayer(
193
+ n_head,
194
+ in_dim,
195
+ d_head,
196
+ d_inner,
197
+ kernel_size,
198
+ dropout,
199
+ dropatt=dropatt,
200
+ )
201
+ )
202
+
203
+ self.dense = LinearNorm(in_dim, out_dim)
204
+
205
+ def forward(self, dec_inp, in_lens):
206
+ # B, C, T --> B, T, C
207
+ inp = dec_inp.transpose(1, 2)
208
+ mask = get_mask_from_lengths(in_lens)[..., None]
209
+
210
+ pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
211
+ pos_emb = self.pos_emb(pos_seq) * mask
212
+
213
+ out = self.drop(inp + pos_emb)
214
+
215
+ for layer in self.layers:
216
+ out = layer(out, mask=mask)
217
+
218
+ out = self.dense(out).transpose(1, 2)
219
+ return out
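A small smoke test of `FFTransformer` as defined above, run on a dummy mel-sized batch. It assumes the file is saved as `transformer.py` and that `common.py` from this repository (providing `get_mask_from_lengths` and `LinearNorm`) is on the import path; the layer sizes are arbitrary illustration values.

```python
import torch

from transformer import FFTransformer

model = FFTransformer(
    in_dim=80, out_dim=80, n_layers=2, n_head=1, d_head=64, d_inner=256, kernel_size=3
)
model.eval()

dec_inp = torch.randn(2, 80, 120)  # B x C x T
in_lens = torch.tensor([120, 95])  # valid lengths per batch item

with torch.no_grad():
    out = model(dec_inp, in_lens)

print(out.shape)  # torch.Size([2, 80, 120])
```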
tts_text_processing/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2017 Keith Ito
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
tts_text_processing/abbreviations.py ADDED
@@ -0,0 +1,57 @@
1
+ import re
2
+
3
+ _no_period_re = re.compile(r"(No[.])(?=[ ]?[0-9])")
4
+ _percent_re = re.compile(r"([ ]?[%])")
5
+ _half_re = re.compile("([0-9]½)|(½)")
6
+
7
+
8
+ # List of (regular expression, replacement) pairs for abbreviations:
9
+ _abbreviations = [
10
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
11
+ for x in [
12
+ ("mrs", "misess"),
13
+ ("ms", "miss"),
14
+ ("mr", "mister"),
15
+ ("dr", "doctor"),
16
+ ("st", "saint"),
17
+ ("co", "company"),
18
+ ("jr", "junior"),
19
+ ("maj", "major"),
20
+ ("gen", "general"),
21
+ ("drs", "doctors"),
22
+ ("rev", "reverend"),
23
+ ("lt", "lieutenant"),
24
+ ("hon", "honorable"),
25
+ ("sgt", "sergeant"),
26
+ ("capt", "captain"),
27
+ ("esq", "esquire"),
28
+ ("ltd", "limited"),
29
+ ("col", "colonel"),
30
+ ("ft", "fort"),
31
+ ]
32
+ ]
33
+
34
+
35
+ def _expand_no_period(m):
36
+ word = m.group(0)
37
+ if word[0] == "N":
38
+ return "Number"
39
+ return "number"
40
+
41
+
42
+ def _expand_percent(m):
43
+ return " percent"
44
+
45
+
46
+ def _expand_half(m):
47
+ word = m.group(1)
48
+ if word is None:
49
+ return "half"
50
+ return word[0] + " and a half"
51
+
52
+
53
+ def normalize_abbreviations(text):
54
+ text = re.sub(_no_period_re, _expand_no_period, text)
55
+ text = re.sub(_percent_re, _expand_percent, text)
56
+ text = re.sub(_half_re, _expand_half, text)
57
+ return text
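For reference, a small example of what `normalize_abbreviations` above actually rewrites: only the "No.", percent, and ½ patterns (the `_abbreviations` list is defined here but not applied by this function). The package-style import assumes the module is used as part of `tts_text_processing`.

```python
from tts_text_processing.abbreviations import normalize_abbreviations

print(normalize_abbreviations("Bus No. 4 was ½ full, about 50% by weight."))
# -> "Bus Number 4 was half full, about 50 percent by weight."
```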
tts_text_processing/acronyms.py ADDED
@@ -0,0 +1,69 @@
1
+ import re
2
+
3
+ _letter_to_arpabet = {
4
+ "A": "EY1",
5
+ "B": "B IY1",
6
+ "C": "S IY1",
7
+ "D": "D IY1",
8
+ "E": "IY1",
9
+ "F": "EH1 F",
10
+ "G": "JH IY1",
11
+ "H": "EY1 CH",
12
+ "I": "AY1",
13
+ "J": "JH EY1",
14
+ "K": "K EY1",
15
+ "L": "EH1 L",
16
+ "M": "EH1 M",
17
+ "N": "EH1 N",
18
+ "O": "OW1",
19
+ "P": "P IY1",
20
+ "Q": "K Y UW1",
21
+ "R": "AA1 R",
22
+ "S": "EH1 S",
23
+ "T": "T IY1",
24
+ "U": "Y UW1",
25
+ "V": "V IY1",
26
+ "X": "EH1 K S",
27
+ "Y": "W AY1",
28
+ "W": "D AH1 B AH0 L Y UW0",
29
+ "Z": "Z IY1",
30
+ "s": "Z",
31
+ }
32
+
33
+ # must ignore roman numerals
34
+ # _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
35
+ _acronym_re = re.compile(r"([A-Z][A-Z]+)s?")
36
+
37
+
38
+ class AcronymNormalizer(object):
39
+ def __init__(self, phoneme_dict):
40
+ self.phoneme_dict = phoneme_dict
41
+
42
+ def normalize_acronyms(self, text):
43
+ def _expand_acronyms(m, add_spaces=True):
44
+ acronym = m.group(0)
45
+ # remove dots if they exist
46
+ acronym = re.sub(r"\.", "", acronym)
47
+
48
+ acronym = "".join(acronym.split())
49
+ arpabet = self.phoneme_dict.lookup(acronym)
50
+
51
+ if arpabet is None:
52
+ acronym = list(acronym)
53
+ arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
54
+ # temporary fix
55
+ if arpabet[-1] == "{Z}" and len(arpabet) > 1:
56
+ arpabet[-2] = arpabet[-2][:-1] + " " + arpabet[-1][1:]
57
+ del arpabet[-1]
58
+ arpabet = " ".join(arpabet)
59
+ elif len(arpabet) == 1:
60
+ arpabet = "{" + arpabet[0] + "}"
61
+ else:
62
+ arpabet = acronym
63
+ return arpabet
64
+
65
+ text = re.sub(_acronym_re, _expand_acronyms, text)
66
+ return text
67
+
68
+ def __call__(self, text):
69
+ return self.normalize_acronyms(text)
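A sketch of `AcronymNormalizer` used in isolation. The real phoneme dictionary comes from the CMUdict wrapper elsewhere in this package; the stub below is a hypothetical stand-in whose lookups always miss, which forces the letter-by-letter ARPAbet expansion.

```python
from tts_text_processing.acronyms import AcronymNormalizer


class MissDict:
    """Hypothetical stand-in for the CMU phoneme dict: every lookup misses."""

    def lookup(self, word):
        return None


normalizer = AcronymNormalizer(MissDict())
print(normalizer("NASA HQ"))
# -> "{EH1 N} {EY1} {EH1 S} {EY1} {EY1 CH} {K Y UW1}"
```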
tts_text_processing/cleaners.py ADDED
@@ -0,0 +1,123 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ """
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ """
14
+
15
+ import re
16
+ from string import punctuation
17
+ from functools import reduce
18
+ from unidecode import unidecode
19
+ from .numerical import normalize_numbers, normalize_currency
20
+ from .acronyms import AcronymNormalizer
21
+ from .datestime import normalize_datestime
22
+ from .letters_and_numbers import normalize_letters_and_numbers
23
+ from .abbreviations import normalize_abbreviations
24
+
25
+
26
+ # Regular expression matching whitespace:
27
+ _whitespace_re = re.compile(r"\s+")
28
+
29
+ # Regular expression separating words enclosed in curly braces for cleaning
30
+ _arpa_re = re.compile(r"{[^}]+}|\S+")
31
+
32
+
33
+ def expand_abbreviations(text):
34
+ return normalize_abbreviations(text)
35
+
36
+
37
+ def expand_numbers(text):
38
+ return normalize_numbers(text)
39
+
40
+
41
+ def expand_currency(text):
42
+ return normalize_currency(text)
43
+
44
+
45
+ def expand_datestime(text):
46
+ return normalize_datestime(text)
47
+
48
+
49
+ def expand_letters_and_numbers(text):
50
+ return normalize_letters_and_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ return text.lower()
55
+
56
+
57
+ def collapse_whitespace(text):
58
+ return re.sub(_whitespace_re, " ", text)
59
+
60
+
61
+ def separate_acronyms(text):
62
+ text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
63
+ text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
64
+ return text
65
+
66
+
67
+ def convert_to_ascii(text):
68
+ return unidecode(text)
69
+
70
+
71
+ def dehyphenize_compound_words(text):
72
+ text = re.sub(r"(?<=[a-zA-Z0-9])-(?=[a-zA-Z])", " ", text)
73
+ return text
74
+
75
+
76
+ def remove_space_before_punctuation(text):
77
+ return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)
78
+
79
+
80
+ class Cleaner(object):
81
+ def __init__(self, cleaner_names, phonemedict):
82
+ self.cleaner_names = cleaner_names
83
+ self.phonemedict = phonemedict
84
+ self.acronym_normalizer = AcronymNormalizer(self.phonemedict)
85
+
86
+ def __call__(self, text):
87
+ for cleaner_name in self.cleaner_names:
88
+ sequence_fns, word_fns = self.get_cleaner_fns(cleaner_name)
89
+ for fn in sequence_fns:
90
+ text = fn(text)
91
+
92
+ text = [
93
+ reduce(lambda x, y: y(x), word_fns, split) if split[0] != "{" else split
94
+ for split in _arpa_re.findall(text)
95
+ ]
96
+ text = " ".join(text)
97
+ text = remove_space_before_punctuation(text)
98
+ return text
99
+
100
+ def get_cleaner_fns(self, cleaner_name):
101
+ if cleaner_name == "basic_cleaners":
102
+ sequence_fns = [lowercase, collapse_whitespace]
103
+ word_fns = []
104
+ elif cleaner_name == "english_cleaners":
105
+ sequence_fns = [collapse_whitespace, convert_to_ascii, lowercase]
106
+ word_fns = [expand_numbers, expand_abbreviations]
107
+ elif cleaner_name == "radtts_cleaners":
108
+ sequence_fns = [
109
+ collapse_whitespace,
110
+ expand_currency,
111
+ expand_datestime,
112
+ expand_letters_and_numbers,
113
+ ]
114
+ word_fns = [expand_numbers, expand_abbreviations]
115
+ elif cleaner_name == "ukrainian_cleaners":
116
+ sequence_fns = [lowercase, collapse_whitespace]
117
+ word_fns = []
118
+ elif cleaner_name == "transliteration_cleaners":
119
+ sequence_fns = [convert_to_ascii, lowercase, collapse_whitespace]
120
+ else:
121
+ raise Exception("{} cleaner not supported".format(cleaner_name))
122
+
123
+ return sequence_fns, word_fns
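As a quick illustration of the `ukrainian_cleaners` path, which only lowercases and collapses whitespace, no phoneme dictionary is needed (a sketch, assuming it is run from the repository root):

```python
from tts_text_processing.cleaners import Cleaner

cleaner = Cleaner(["ukrainian_cleaners"], phonemedict=None)
print(cleaner("Привіт,   Світе!"))  # -> "привіт, світе!"
```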
tts_text_processing/cmudict.py ADDED
@@ -0,0 +1,140 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import re
4
+
5
+
6
+ valid_symbols = [
7
+ "AA",
8
+ "AA0",
9
+ "AA1",
10
+ "AA2",
11
+ "AE",
12
+ "AE0",
13
+ "AE1",
14
+ "AE2",
15
+ "AH",
16
+ "AH0",
17
+ "AH1",
18
+ "AH2",
19
+ "AO",
20
+ "AO0",
21
+ "AO1",
22
+ "AO2",
23
+ "AW",
24
+ "AW0",
25
+ "AW1",
26
+ "AW2",
27
+ "AY",
28
+ "AY0",
29
+ "AY1",
30
+ "AY2",
31
+ "B",
32
+ "CH",
33
+ "D",
34
+ "DH",
35
+ "EH",
36
+ "EH0",
37
+ "EH1",
38
+ "EH2",
39
+ "ER",
40
+ "ER0",
41
+ "ER1",
42
+ "ER2",
43
+ "EY",
44
+ "EY0",
45
+ "EY1",
46
+ "EY2",
47
+ "F",
48
+ "G",
49
+ "HH",
50
+ "IH",
51
+ "IH0",
52
+ "IH1",
53
+ "IH2",
54
+ "IY",
55
+ "IY0",
56
+ "IY1",
57
+ "IY2",
58
+ "JH",
59
+ "K",
60
+ "L",
61
+ "M",
62
+ "N",
63
+ "NG",
64
+ "OW",
65
+ "OW0",
66
+ "OW1",
67
+ "OW2",
68
+ "OY",
69
+ "OY0",
70
+ "OY1",
71
+ "OY2",
72
+ "P",
73
+ "R",
74
+ "S",
75
+ "SH",
76
+ "T",
77
+ "TH",
78
+ "UH",
79
+ "UH0",
80
+ "UH1",
81
+ "UH2",
82
+ "UW",
83
+ "UW0",
84
+ "UW1",
85
+ "UW2",
86
+ "V",
87
+ "W",
88
+ "Y",
89
+ "Z",
90
+ "ZH",
91
+ ]
92
+
93
+ _valid_symbol_set = set(valid_symbols)
94
+
95
+
96
+ class CMUDict:
97
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98
+
99
+ def __init__(self, file_or_path, keep_ambiguous=True):
100
+ if isinstance(file_or_path, str):
101
+ with open(file_or_path, encoding="latin-1") as f:
102
+ entries = _parse_cmudict(f)
103
+ else:
104
+ entries = _parse_cmudict(file_or_path)
105
+ if not keep_ambiguous:
106
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107
+ self._entries = entries
108
+
109
+ def __len__(self):
110
+ return len(self._entries)
111
+
112
+ def lookup(self, word):
113
+ """Returns list of ARPAbet pronunciations of the given word."""
114
+ return self._entries.get(word.upper())
115
+
116
+
117
+ _alt_re = re.compile(r"\([0-9]+\)")
118
+
119
+
120
+ def _parse_cmudict(file):
121
+ cmudict = {}
122
+ for line in file:
123
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124
+ parts = line.split(" ")
125
+ word = re.sub(_alt_re, "", parts[0])
126
+ pronunciation = _get_pronunciation(parts[1])
127
+ if pronunciation:
128
+ if word in cmudict:
129
+ cmudict[word].append(pronunciation)
130
+ else:
131
+ cmudict[word] = [pronunciation]
132
+ return cmudict
133
+
134
+
135
+ def _get_pronunciation(s):
136
+ parts = s.strip().split(" ")
137
+ for part in parts:
138
+ if part not in _valid_symbol_set:
139
+ return None
140
+ return " ".join(parts)
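Typical usage of the wrapper looks like the sketch below; the dictionary file path is an assumption (the file itself is not part of this commit):

```python
from tts_text_processing.cmudict import CMUDict

cmudict = CMUDict("models/cmudict-0.7b")  # hypothetical local path
print(len(cmudict))             # number of words parsed from the file
print(cmudict.lookup("hello"))  # list of ARPAbet strings, or None if unknown
```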
tts_text_processing/datestime.py ADDED
@@ -0,0 +1,24 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import re
4
+
5
+ _ampm_re = re.compile(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)")
6
+
7
+
8
+ def _expand_ampm(m):
9
+ matches = list(m.groups(0))
10
+ txt = matches[0]
11
+ txt = txt if int(matches[1]) == 0 else txt + " " + matches[1]
12
+
13
+ if matches[2][0].lower() == "a":
14
+ txt += " a.m."
15
+ elif matches[2][0].lower() == "p":
16
+ txt += " p.m."
17
+
18
+ return txt
19
+
20
+
21
+ def normalize_datestime(text):
22
+ text = re.sub(_ampm_re, _expand_ampm, text)
23
+ # text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
24
+ return text
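For example (only the a.m./p.m. rule is active; the bare HH:MM rule above is commented out):

```python
from tts_text_processing.datestime import normalize_datestime

print(normalize_datestime("Lunch at 10:30AM, done by 2pm today."))
# -> "Lunch at 10 30 a.m., done by 2 p.m. today."
```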
tts_text_processing/grapheme_dictionary.py ADDED
@@ -0,0 +1,37 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import re
4
+
5
+ _alt_re = re.compile(r"\([0-9]+\)")
6
+
7
+
8
+ class Grapheme2PhonemeDictionary:
9
+ """Thin wrapper around g2p data."""
10
+
11
+ def __init__(self, file_or_path, keep_ambiguous=True, encoding="latin-1"):
12
+ with open(file_or_path, encoding=encoding) as f:
13
+ entries = _parse_g2p(f)
14
+ if not keep_ambiguous:
15
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
16
+ self._entries = entries
17
+
18
+ def __len__(self):
19
+ return len(self._entries)
20
+
21
+ def lookup(self, word):
22
+ """Returns list of pronunciations of the given word."""
23
+ return self._entries.get(word.upper())
24
+
25
+
26
+ def _parse_g2p(file):
27
+ g2p = {}
28
+ for line in file:
29
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
30
+ parts = line.split(" ")
31
+ word = re.sub(_alt_re, "", parts[0])
32
+ pronunciation = parts[1].strip()
33
+ if word in g2p:
34
+ g2p[word].append(pronunciation)
35
+ else:
36
+ g2p[word] = [pronunciation]
37
+ return g2p
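Usage mirrors `CMUDict`, but the pronunciation field is free-form and the file encoding is configurable; the path below is an assumption:

```python
from tts_text_processing.grapheme_dictionary import Grapheme2PhonemeDictionary

g2p = Grapheme2PhonemeDictionary("models/g2p.dict", keep_ambiguous=False, encoding="utf-8")
# With keep_ambiguous=False, words that had several pronunciations are dropped entirely,
# so lookup() returns either a single-element list or None.
print(g2p.lookup("read"))
```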
tts_text_processing/heteronyms ADDED
@@ -0,0 +1,413 @@
1
+ abject
2
+ abrogate
3
+ absent
4
+ abstract
5
+ abuse
6
+ ache
7
+ acre
8
+ acuminate
9
+ addict
10
+ address
11
+ adduct
12
+ adele
13
+ advocate
14
+ affect
15
+ affiliate
16
+ agape
17
+ aged
18
+ agglomerate
19
+ aggregate
20
+ agonic
21
+ agora
22
+ allied
23
+ ally
24
+ alternate
25
+ alum
26
+ am
27
+ analyses
28
+ andrea
29
+ animate
30
+ apply
31
+ appropriate
32
+ approximate
33
+ ares
34
+ arithmetic
35
+ arsenic
36
+ articulate
37
+ associate
38
+ attribute
39
+ august
40
+ axes
41
+ ay
42
+ aye
43
+ bases
44
+ bass
45
+ bathed
46
+ bested
47
+ bifurcate
48
+ blessed
49
+ blotto
50
+ bow
51
+ bowed
52
+ bowman
53
+ brassy
54
+ buffet
55
+ bustier
56
+ carbonate
57
+ celtic
58
+ choral
59
+ chumash
60
+ close
61
+ closer
62
+ coax
63
+ coincidence
64
+ color coordinate
65
+ colour coordinate
66
+ comber
67
+ combine
68
+ combs
69
+ committee
70
+ commune
71
+ compact
72
+ complex
73
+ compound
74
+ compress
75
+ concert
76
+ conduct
77
+ confine
78
+ confines
79
+ conflict
80
+ conglomerate
81
+ conscript
82
+ conserve
83
+ consist
84
+ console
85
+ consort
86
+ construct
87
+ consult
88
+ consummate
89
+ content
90
+ contest
91
+ contract
92
+ contracts
93
+ contrast
94
+ converse
95
+ convert
96
+ convict
97
+ coop
98
+ coordinate
99
+ covey
100
+ crooked
101
+ curate
102
+ cussed
103
+ decollate
104
+ decrease
105
+ defect
106
+ defense
107
+ delegate
108
+ deliberate
109
+ denier
110
+ desert
111
+ detail
112
+ deviate
113
+ diagnoses
114
+ diffuse
115
+ digest
116
+ discard
117
+ discharge
118
+ discount
119
+ do
120
+ document
121
+ does
122
+ dogged
123
+ domesticate
124
+ dominican
125
+ dove
126
+ dr
127
+ drawer
128
+ duplicate
129
+ egress
130
+ ejaculate
131
+ eject
132
+ elaborate
133
+ ellipses
134
+ email
135
+ emu
136
+ entrace
137
+ entrance
138
+ escort
139
+ estimate
140
+ eta
141
+ etna
142
+ evening
143
+ excise
144
+ excuse
145
+ exploit
146
+ export
147
+ extract
148
+ fine
149
+ flower
150
+ forbear
151
+ four-legged
152
+ frequent
153
+ furrier
154
+ gallant
155
+ gel
156
+ geminate
157
+ gillie
158
+ glower
159
+ gotham
160
+ graduate
161
+ haggis
162
+ heavy
163
+ hinder
164
+ house
165
+ housewife
166
+ impact
167
+ imped
168
+ implant
169
+ implement
170
+ import
171
+ impress
172
+ incense
173
+ incline
174
+ increase
175
+ infix
176
+ insert
177
+ instar
178
+ insult
179
+ integral
180
+ intercept
181
+ interchange
182
+ interflow
183
+ interleaf
184
+ intermediate
185
+ intern
186
+ interspace
187
+ intimate
188
+ intrigue
189
+ invalid
190
+ invert
191
+ invite
192
+ irony
193
+ jagged
194
+ jesses
195
+ julies
196
+ kite
197
+ laminate
198
+ laos
199
+ lather
200
+ lead
201
+ learned
202
+ leasing
203
+ lech
204
+ legitimate
205
+ lied
206
+ lima
207
+ lipread
208
+ live
209
+ lower
210
+ lunged
211
+ maas
212
+ magdalen
213
+ manes
214
+ mare
215
+ marked
216
+ merchandise
217
+ merlion
218
+ minute
219
+ misconduct
220
+ misled
221
+ misprint
222
+ mobile
223
+ moderate
224
+ mong
225
+ moped
226
+ moth
227
+ mouth
228
+ mow
229
+ mpg
230
+ multiply
231
+ mush
232
+ nana
233
+ nice
234
+ nice
235
+ number
236
+ numerate
237
+ nun
238
+ object
239
+ opiate
240
+ ornament
241
+ outbox
242
+ outcry
243
+ outpour
244
+ outreach
245
+ outride
246
+ outright
247
+ outside
248
+ outwork
249
+ overall
250
+ overbid
251
+ overcall
252
+ overcast
253
+ overfall
254
+ overflow
255
+ overhaul
256
+ overhead
257
+ overlap
258
+ overlay
259
+ overuse
260
+ overweight
261
+ overwork
262
+ pace
263
+ palled
264
+ palling
265
+ para
266
+ pasty
267
+ pate
268
+ pauline
269
+ pedal
270
+ peer
271
+ perfect
272
+ periodic
273
+ permit
274
+ pervert
275
+ pinta
276
+ placer
277
+ platy
278
+ polish
279
+ polish
280
+ poll
281
+ pontificate
282
+ postulate
283
+ pram
284
+ prayer
285
+ precipitate
286
+ predate
287
+ predicate
288
+ prefix
289
+ preposition
290
+ present
291
+ pretest
292
+ primer
293
+ proceeds
294
+ produce
295
+ progress
296
+ project
297
+ proportionate
298
+ prospect
299
+ protest
300
+ pussy
301
+ putter
302
+ putting
303
+ quite
304
+ ragged
305
+ raven
306
+ re
307
+ read
308
+ reading
309
+ reading
310
+ real
311
+ rebel
312
+ recall
313
+ recap
314
+ recitative
315
+ recollect
316
+ record
317
+ recreate
318
+ recreation
319
+ redress
320
+ refill
321
+ refund
322
+ refuse
323
+ reject
324
+ relay
325
+ remake
326
+ repaint
327
+ reprint
328
+ reread
329
+ rerun
330
+ resent
331
+ reside
332
+ resign
333
+ respray
334
+ resume
335
+ retard
336
+ retest
337
+ retread
338
+ rewrite
339
+ root
340
+ routed
341
+ routing
342
+ row
343
+ rugged
344
+ rummy
345
+ sais
346
+ sake
347
+ sambuca
348
+ saucier
349
+ second
350
+ secrete
351
+ secreted
352
+ secreting
353
+ segment
354
+ separate
355
+ sewer
356
+ shirk
357
+ shower
358
+ sin
359
+ skied
360
+ slaver
361
+ slough
362
+ sow
363
+ spoof
364
+ squid
365
+ stingy
366
+ subject
367
+ subordinate
368
+ subvert
369
+ supply
370
+ supposed
371
+ survey
372
+ suspect
373
+ syringes
374
+ tabulate
375
+ tales
376
+ tarrier
377
+ tarry
378
+ taxes
379
+ taxis
380
+ tear
381
+ theron
382
+ thou
383
+ three-legged
384
+ tier
385
+ tinged
386
+ torment
387
+ transfer
388
+ transform
389
+ transplant
390
+ transport
391
+ transpose
392
+ tush
393
+ two-legged
394
+ unionised
395
+ unionized
396
+ update
397
+ uplift
398
+ upset
399
+ use
400
+ used
401
+ vale
402
+ violist
403
+ viva
404
+ ware
405
+ whinged
406
+ whoop
407
+ wicked
408
+ wind
409
+ windy
410
+ wino
411
+ won
412
+ worsted
413
+ wound
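This word list is consumed by `TextProcessing` below: a word found here is kept as graphemes instead of being routed through the phoneme dictionary. A quick check (sketch):

```python
with open("tts_text_processing/heteronyms", encoding="utf-8") as f:
    heteronyms = {line.strip() for line in f if line.strip()}

print("read" in heteronyms)  # True -> get_phoneme() returns the word unchanged
```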
tts_text_processing/letters_and_numbers.py ADDED
@@ -0,0 +1,96 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import re
4
+
5
+ _letters_and_numbers_re = re.compile(
6
+ r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE
7
+ )
8
+
9
+ _hardware_re = re.compile(
10
+ r"([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)", re.IGNORECASE
11
+ )
12
+ _hardware_key = {
13
+ "tb": "terabyte",
14
+ "gb": "gigabyte",
15
+ "mb": "megabyte",
16
+ "kb": "kilobyte",
17
+ "ghz": "gigahertz",
18
+ "mhz": "megahertz",
19
+ "khz": "kilohertz",
20
+ "hz": "hertz",
21
+ "mm": "millimeter",
22
+ "cm": "centimeter",
23
+ "km": "kilometer",
24
+ }
25
+
26
+ _dimension_re = re.compile(
27
+ r"\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
28
+ )
29
+ _dimension_key = {"m": "meter", "in": "inch", "inch": "inch"}
30
+
31
+
32
+ def _expand_letters_and_numbers(m):
33
+ text = re.split(r"(\d+)", m.group(0))
34
+
35
+ # remove trailing space
36
+ if text[-1] == "":
37
+ text = text[:-1]
38
+ elif text[0] == "":
39
+ text = text[1:]
40
+
41
+ # if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
42
+ if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
43
+ text[-2] = text[-2] + text[-1]
44
+ text = text[:-1]
45
+
46
+ # for combining digits 2 by 2
47
+ new_text = []
48
+ for i in range(len(text)):
49
+ string = text[i]
50
+ if string.isdigit() and len(string) < 5:
51
+ # heuristics
52
+ if len(string) > 2 and string[-2] == "0":
53
+ if string[-1] == "0":
54
+ string = [string]
55
+ else:
56
+ string = [string[:-3], string[-2], string[-1]]
57
+ elif len(string) % 2 == 0:
58
+ string = [string[i : i + 2] for i in range(0, len(string), 2)]
59
+ elif len(string) > 2:
60
+ string = [string[0]] + [
61
+ string[i : i + 2] for i in range(1, len(string), 2)
62
+ ]
63
+ new_text.extend(string)
64
+ else:
65
+ new_text.append(string)
66
+
67
+ text = new_text
68
+ text = " ".join(text)
69
+ return text
70
+
71
+
72
+ def _expand_hardware(m):
73
+ quantity, measure = m.groups(0)
74
+ measure = _hardware_key[measure.lower()]
75
+ if measure[-1] != "z" and float(quantity.replace(",", "")) > 1:
76
+ return "{} {}s".format(quantity, measure)
77
+ return "{} {}".format(quantity, measure)
78
+
79
+
80
+ def _expand_dimension(m):
81
+ text = "".join([x for x in m.groups(0) if x != 0])
82
+ text = text.replace(" x ", " by ")
83
+ text = text.replace("x", " by ")
84
+ if text.endswith(tuple(_dimension_key.keys())):
85
+ if text[-2].isdigit():
86
+ text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
87
+ elif text[-3].isdigit():
88
+ text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
89
+ return text
90
+
91
+
92
+ def normalize_letters_and_numbers(text):
93
+ text = re.sub(_hardware_re, _expand_hardware, text)
94
+ text = re.sub(_dimension_re, _expand_dimension, text)
95
+ text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
96
+ return text
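A quick example of the combined hardware / dimension / alphanumeric expansion (a sketch):

```python
from tts_text_processing.letters_and_numbers import normalize_letters_and_numbers

print(normalize_letters_and_numbers("RTX3090 with 24GB VRAM"))
# -> "RTX 30 90 with 24 gigabytes VRAM"
```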
tts_text_processing/numerical.py ADDED
@@ -0,0 +1,175 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import inflect
4
+ import re
5
+
6
+ _magnitudes = ["trillion", "billion", "million", "thousand", "hundred", "m", "b", "t"]
7
+ _magnitudes_key = {"m": "million", "b": "billion", "t": "trillion"}
8
+ _measurements = "(f|c|k|d|m)"
9
+ _measurements_key = {"f": "fahrenheit", "c": "celsius", "k": "thousand", "m": "meters"}
10
+ _currency_key = {"$": "dollar", "£": "pound", "€": "euro", "₩": "won"}
11
+ _inflect = inflect.engine()
12
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
13
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
14
+ _currency_re = re.compile(
15
+ r"([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?".format(
16
+ "|".join(_magnitudes)
17
+ ),
18
+ re.IGNORECASE,
19
+ )
20
+ _measurement_re = re.compile(
21
+ r"([0-9\.\,]*[0-9]+(\s)?{}\b)".format(_measurements), re.IGNORECASE
22
+ )
23
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
24
+ # _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
25
+ _roman_re = re.compile(
26
+ r"\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b"
27
+ ) # avoid I
28
+ _multiply_re = re.compile(r"(\b[0-9]+)(x)([0-9]+)")
29
+ _number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
30
+
31
+
32
+ def _remove_commas(m):
33
+ return m.group(1).replace(",", "")
34
+
35
+
36
+ def _expand_decimal_point(m):
37
+ return m.group(1).replace(".", " point ")
38
+
39
+
40
+ def _expand_currency(m):
41
+ currency = _currency_key[m.group(1)]
42
+ quantity = m.group(2)
43
+ magnitude = m.group(3)
44
+
45
+ # remove commas from quantity to be able to convert to numerical
46
+ quantity = quantity.replace(",", "")
47
+
48
+ # check for million, billion, etc...
49
+ if magnitude is not None and magnitude.lower() in _magnitudes:
50
+ if len(magnitude) == 1:
51
+ magnitude = _magnitudes_key[magnitude.lower()]
52
+ return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + "s")
53
+
54
+ parts = quantity.split(".")
55
+ if len(parts) > 2:
56
+ return quantity + " " + currency + "s" # Unexpected format
57
+
58
+ dollars = int(parts[0]) if parts[0] else 0
59
+
60
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
61
+ if dollars and cents:
62
+ dollar_unit = currency if dollars == 1 else currency + "s"
63
+ cent_unit = "cent" if cents == 1 else "cents"
64
+ return "{} {}, {} {}".format(
65
+ _expand_hundreds(dollars),
66
+ dollar_unit,
67
+ _inflect.number_to_words(cents),
68
+ cent_unit,
69
+ )
70
+ elif dollars:
71
+ dollar_unit = currency if dollars == 1 else currency + "s"
72
+ return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
73
+ elif cents:
74
+ cent_unit = "cent" if cents == 1 else "cents"
75
+ return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
76
+ else:
77
+ return "zero" + " " + currency + "s"
78
+
79
+
80
+ def _expand_hundreds(text):
81
+ number = float(text)
82
+ if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
83
+ return _inflect.number_to_words(int(number / 100)) + " hundred"
84
+ else:
85
+ return _inflect.number_to_words(text)
86
+
87
+
88
+ def _expand_ordinal(m):
89
+ return _inflect.number_to_words(m.group(0))
90
+
91
+
92
+ def _expand_measurement(m):
93
+ _, number, measurement = re.split(r"(\d+(?:\.\d+)?)", m.group(0))
94
+ number = _inflect.number_to_words(number)
95
+ measurement = "".join(measurement.split())
96
+ measurement = _measurements_key[measurement.lower()]
97
+ return "{} {}".format(number, measurement)
98
+
99
+
100
+ def _expand_range(m):
101
+ return " to "
102
+
103
+
104
+ def _expand_multiply(m):
105
+ left = m.group(1)
106
+ right = m.group(3)
107
+ return "{} by {}".format(left, right)
108
+
109
+
110
+ def _expand_roman(m):
111
+ # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
112
+ roman_numerals = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
113
+ result = 0
114
+ num = m.group(0)
115
+ for i, c in enumerate(num):
116
+ if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
117
+ result += roman_numerals[c]
118
+ else:
119
+ result -= roman_numerals[c]
120
+ return str(result)
121
+
122
+
123
+ def _expand_number(m):
124
+ _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
125
+ number = int(number)
126
+ if (
127
+ number > 1000
128
+ and number < 10000
129
+ and (number % 100 == 0)
130
+ and (number % 1000 != 0)
131
+ ):
132
+ text = _inflect.number_to_words(number // 100) + " hundred"
133
+ elif number > 1000 and number < 3000:
134
+ if number == 2000:
135
+ text = "two thousand"
136
+ elif number > 2000 and number < 2010:
137
+ text = "two thousand " + _inflect.number_to_words(number % 100)
138
+ elif number % 100 == 0:
139
+ text = _inflect.number_to_words(number // 100) + " hundred"
140
+ else:
141
+ number = _inflect.number_to_words(
142
+ number, andword="", zero="oh", group=2
143
+ ).replace(", ", " ")
144
+ number = re.sub(r"-", " ", number)
145
+ text = number
146
+ else:
147
+ number = _inflect.number_to_words(number, andword="and")
148
+ number = re.sub(r"-", " ", number)
149
+ number = re.sub(r",", "", number)
150
+ text = number
151
+
152
+ if suffix in ("'s", "s"):
153
+ if text[-1] == "y":
154
+ text = text[:-1] + "ies"
155
+ else:
156
+ text = text + suffix
157
+
158
+ return text
159
+
160
+
161
+ def normalize_currency(text):
162
+ return re.sub(_currency_re, _expand_currency, text)
163
+
164
+
165
+ def normalize_numbers(text):
166
+ text = re.sub(_comma_number_re, _remove_commas, text)
167
+ text = re.sub(_currency_re, _expand_currency, text)
168
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
169
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
170
+ # text = re.sub(_range_re, _expand_range, text)
171
+ # text = re.sub(_measurement_re, _expand_measurement, text)
172
+ text = re.sub(_roman_re, _expand_roman, text)
173
+ text = re.sub(_multiply_re, _expand_multiply, text)
174
+ text = re.sub(_number_re, _expand_number, text)
175
+ return text
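For example, currency amounts and year-like numbers are verbalized differently (a sketch):

```python
from tts_text_processing.numerical import normalize_numbers

print(normalize_numbers("It cost $3.50 in 1999."))
# -> "It cost three dollars, fifty cents in nineteen ninety nine."
```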
tts_text_processing/symbols.py ADDED
@@ -0,0 +1,144 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ """
4
+ Defines the set of symbols used in text input to the model.
5
+
6
+ The default is a set of ASCII characters that works well for English or text
7
+ that has been run through Unidecode. For other data, you can modify
8
+ _characters."""
9
+
10
+
11
+ arpabet = [
12
+ "AA",
13
+ "AA0",
14
+ "AA1",
15
+ "AA2",
16
+ "AE",
17
+ "AE0",
18
+ "AE1",
19
+ "AE2",
20
+ "AH",
21
+ "AH0",
22
+ "AH1",
23
+ "AH2",
24
+ "AO",
25
+ "AO0",
26
+ "AO1",
27
+ "AO2",
28
+ "AW",
29
+ "AW0",
30
+ "AW1",
31
+ "AW2",
32
+ "AY",
33
+ "AY0",
34
+ "AY1",
35
+ "AY2",
36
+ "B",
37
+ "CH",
38
+ "D",
39
+ "DH",
40
+ "EH",
41
+ "EH0",
42
+ "EH1",
43
+ "EH2",
44
+ "ER",
45
+ "ER0",
46
+ "ER1",
47
+ "ER2",
48
+ "EY",
49
+ "EY0",
50
+ "EY1",
51
+ "EY2",
52
+ "F",
53
+ "G",
54
+ "HH",
55
+ "IH",
56
+ "IH0",
57
+ "IH1",
58
+ "IH2",
59
+ "IY",
60
+ "IY0",
61
+ "IY1",
62
+ "IY2",
63
+ "JH",
64
+ "K",
65
+ "L",
66
+ "M",
67
+ "N",
68
+ "NG",
69
+ "OW",
70
+ "OW0",
71
+ "OW1",
72
+ "OW2",
73
+ "OY",
74
+ "OY0",
75
+ "OY1",
76
+ "OY2",
77
+ "P",
78
+ "R",
79
+ "S",
80
+ "SH",
81
+ "T",
82
+ "TH",
83
+ "UH",
84
+ "UH0",
85
+ "UH1",
86
+ "UH2",
87
+ "UW",
88
+ "UW0",
89
+ "UW1",
90
+ "UW2",
91
+ "V",
92
+ "W",
93
+ "Y",
94
+ "Z",
95
+ "ZH",
96
+ ]
97
+
98
+
99
+ def get_symbols(symbol_set):
100
+ if symbol_set == "english_basic":
101
+ _pad = "_"
102
+ _punctuation = "!'\"(),.:;? "
103
+ _special = "-"
104
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
105
+ _arpabet = ["@" + s for s in arpabet]
106
+ symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
107
+ elif symbol_set == "english_basic_lowercase":
108
+ _pad = "_"
109
+ _punctuation = "!'\"(),.:;? "
110
+ _special = "-"
111
+ _letters = "abcdefghijklmnopqrstuvwxyz"
112
+ _arpabet = ["@" + s for s in arpabet]
113
+ symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
114
+ elif symbol_set == "english_expanded":
115
+ _punctuation = "!'\",.:;? "
116
+ _math = "#%&*+-/[]()"
117
+ _special = "_@©°½—₩€$"
118
+ _accented = "áçéêëñöøćž"
119
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
120
+ _arpabet = ["@" + s for s in arpabet]
121
+ symbols = (
122
+ list(_punctuation + _math + _special + _accented + _letters) + _arpabet
123
+ )
124
+ elif symbol_set == "ukrainian":
125
+ _punctuation = "'.,?! "
126
+ _special = "-+"
127
+ _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
128
+ symbols = list(_punctuation + _special + _letters)
129
+ elif symbol_set == "radtts":
130
+ _punctuation = "!'\",.:;? "
131
+ _math = "#%&*+-/[]()"
132
+ _special = "_@©°½—₩€$"
133
+ _accented = "áçéêëñöøćž"
134
+ _numbers = "0123456789"
135
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
136
+ _arpabet = ["@" + s for s in arpabet]
137
+ symbols = (
138
+ list(_punctuation + _math + _special + _accented + _numbers + _letters)
139
+ + _arpabet
140
+ )
141
+ else:
142
+ raise Exception("{} symbol set does not exist".format(symbol_set))
143
+
144
+ return symbols
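The `ukrainian` set is purely character-based, for instance:

```python
from tts_text_processing.symbols import get_symbols

symbols = get_symbols("ukrainian")
print(len(symbols))  # 41 = 6 punctuation symbols (incl. space) + "-" + "+" + 33 letters
print(symbols[:10])
```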
tts_text_processing/text_processing.py ADDED
@@ -0,0 +1,201 @@
1
+ """adapted from https://github.com/keithito/tacotron"""
2
+
3
+ import re
4
+ import numpy as np
5
+ from .cleaners import Cleaner
6
+ from .symbols import get_symbols
7
+ from .grapheme_dictionary import Grapheme2PhonemeDictionary
8
+
9
+
10
+ #########
11
+ # REGEX #
12
+ #########
13
+
14
+ # Regular expression matching text enclosed in curly braces for encoding
15
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
16
+
17
+ # Regular expression matching words and not words
18
+ _words_re = re.compile(
19
+ r"([a-zA-ZÀ-ž]+['][a-zA-ZÀ-ž]+|[a-zA-ZÀ-ž]+)|([{][^}]+[}]|[^a-zA-ZÀ-ž{}]+)"
20
+ )
21
+
22
+
23
+ def lines_to_list(filename):
24
+ with open(filename, encoding="utf-8") as f:
25
+ lines = f.readlines()
26
+ lines = [l.rstrip() for l in lines]
27
+ return lines
28
+
29
+
30
+ class TextProcessing(object):
31
+ def __init__(
32
+ self,
33
+ symbol_set,
34
+ cleaner_name,
35
+ heteronyms_path,
36
+ phoneme_dict_path,
37
+ p_phoneme,
38
+ handle_phoneme,
39
+ handle_phoneme_ambiguous,
40
+ prepend_space_to_text=False,
41
+ append_space_to_text=False,
42
+ add_bos_eos_to_text=False,
43
+ encoding="latin-1",
44
+ ):
45
+ if heteronyms_path is not None and heteronyms_path != "":
46
+ self.heteronyms = set(lines_to_list(heteronyms_path))
47
+ else:
48
+ self.heteronyms = []
49
+ # phoneme dict
50
+ self.phonemedict = {}
51
+
52
+ self.p_phoneme = p_phoneme
53
+ self.handle_phoneme = handle_phoneme
54
+ self.handle_phoneme_ambiguous = handle_phoneme_ambiguous
55
+
56
+ self.symbols = get_symbols(symbol_set)
57
+ self.cleaner_names = cleaner_name
58
+ self.cleaner = Cleaner(cleaner_name, self.phonemedict)
59
+
60
+ self.prepend_space_to_text = prepend_space_to_text
61
+ self.append_space_to_text = append_space_to_text
62
+ self.add_bos_eos_to_text = add_bos_eos_to_text
63
+
64
+ if add_bos_eos_to_text:
65
+ self.symbols.append("<bos>")
66
+ self.symbols.append("<eos>")
67
+
68
+ # Mappings from symbol to numeric ID and vice versa:
69
+ self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
70
+ self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
71
+
72
+ def text_to_sequence(self, text):
73
+ sequence = []
74
+
75
+ # Check for curly braces and treat their contents as phoneme:
76
+ while len(text):
77
+ m = _curly_re.match(text)
78
+ if not m:
79
+ sequence += self.symbols_to_sequence(text)
80
+ break
81
+ sequence += self.symbols_to_sequence(m.group(1))
82
+ sequence += self.phoneme_to_sequence(m.group(2))
83
+ text = m.group(3)
84
+
85
+ return sequence
86
+
87
+ def sequence_to_text(self, sequence):
88
+ result = ""
89
+ for symbol_id in sequence:
90
+ if symbol_id in self.id_to_symbol:
91
+ s = self.id_to_symbol[symbol_id]
92
+ # Enclose phoneme back in curly braces:
93
+ if len(s) > 1 and s[0] == "@":
94
+ s = "{%s}" % s[1:]
95
+ result += s
96
+ return result.replace("}{", " ")
97
+
98
+ def clean_text(self, text):
99
+ text = self.cleaner(text)
100
+ return text
101
+
102
+ def symbols_to_sequence(self, symbols):
103
+ return [self.symbol_to_id[s] for s in symbols if s in self.symbol_to_id]
104
+
105
+ def phoneme_to_sequence(self, text):
106
+ return self.symbols_to_sequence(["@" + s for s in text.split()])
107
+
108
+ def get_phoneme(self, word):
109
+ phoneme_suffix = ""
110
+
111
+ if word.lower() in self.heteronyms:
112
+ return word
113
+
114
+ if len(word) > 2 and word.endswith("'s"):
115
+ phoneme = self.phonemedict.lookup(word)
116
+ if phoneme is None:
117
+ phoneme = self.phonemedict.lookup(word[:-2])
118
+ phoneme_suffix = "" if phoneme is None else " Z"
119
+
120
+ elif len(word) > 1 and word.endswith("s"):
121
+ phoneme = self.phonemedict.lookup(word)
122
+ if phoneme is None:
123
+ phoneme = self.phonemedict.lookup(word[:-1])
124
+ phoneme_suffix = "" if phoneme is None else " Z"
125
+ else:
126
+ phoneme = self.phonemedict.lookup(word)
127
+
128
+ if phoneme is None:
129
+ return word
130
+
131
+ if len(phoneme) > 1:
132
+ if self.handle_phoneme_ambiguous == "first":
133
+ phoneme = phoneme[0]
134
+ elif self.handle_phoneme_ambiguous == "random":
135
+ phoneme = np.random.choice(phoneme)
136
+ elif self.handle_phoneme_ambiguous == "ignore":
137
+ return word
138
+ else:
139
+ phoneme = phoneme[0]
140
+
141
+ phoneme = "{" + phoneme + phoneme_suffix + "}"
142
+
143
+ return phoneme
144
+
145
+ def encode_text(self, text, return_all=False):
146
+ text_clean = self.clean_text(text)
147
+ text = text_clean
148
+
149
+ text_phoneme = ""
150
+ if self.p_phoneme > 0:
151
+ text_phoneme = self.convert_to_phoneme(text)
152
+ text = text_phoneme
153
+
154
+ text_encoded = self.text_to_sequence(text)
155
+
156
+ if self.prepend_space_to_text:
157
+ text_encoded.insert(0, self.symbol_to_id[" "])
158
+
159
+ if self.append_space_to_text:
160
+ text_encoded.append(self.symbol_to_id[" "])
161
+
162
+ if self.add_bos_eos_to_text:
163
+ text_encoded.insert(0, self.symbol_to_id["<bos>"])
164
+ text_encoded.append(self.symbol_to_id["<eos>"])
165
+
166
+ if return_all:
167
+ return text_encoded, text_clean, text_phoneme
168
+
169
+ return text_encoded
170
+
171
+ def convert_to_phoneme(self, text):
172
+ if self.handle_phoneme == "sentence":
173
+ if np.random.uniform() < self.p_phoneme:
174
+ words = _words_re.findall(text)
175
+ text_phoneme = [
176
+ self.get_phoneme(word[0])
177
+ if (word[0] != "")
178
+ else re.sub(r"\s(\d)", r"\1", word[1].upper())
179
+ for word in words
180
+ ]
181
+ text_phoneme = "".join(text_phoneme)
182
+ text = text_phoneme
183
+ elif self.handle_phoneme == "word":
184
+ words = _words_re.findall(text)
185
+ text_phoneme = [
186
+ re.sub(r"\s(\d)", r"\1", word[1].upper())
187
+ if word[0] == ""
188
+ else (
189
+ self.get_phoneme(word[0])
190
+ if np.random.uniform() < self.p_phoneme
191
+ else word[0]
192
+ )
193
+ for word in words
194
+ ]
195
+ text_phoneme = "".join(text_phoneme)
196
+ text = text_phoneme
197
+ elif self.handle_phoneme != "":
198
+ raise Exception(
199
+ "{} handle_phoneme is not supported".format(self.handle_phoneme)
200
+ )
201
+ return text
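Putting the pieces together, an end-to-end sketch for the Ukrainian path; the constructor arguments here are illustrative assumptions, not necessarily the Space's actual configuration:

```python
from tts_text_processing.text_processing import TextProcessing

tp = TextProcessing(
    symbol_set="ukrainian",
    cleaner_name=["ukrainian_cleaners"],
    heteronyms_path=None,
    phoneme_dict_path=None,
    p_phoneme=0.0,           # grapheme-only: skip phoneme conversion entirely
    handle_phoneme="",
    handle_phoneme_ambiguous="ignore",
)

ids = tp.encode_text("Привіт, як справи?")
print(ids)                       # symbol IDs fed to the model
print(tp.sequence_to_text(ids))  # -> "привіт, як справи?"
```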