- .dockerignore +3 -0
- .gitignore +9 -0
- RADTTS-LICENSE +19 -0
- README.md +24 -7
- alignment.py +54 -0
- app.py +356 -0
- attribute_prediction_model.py +402 -0
- audio_processing.py +328 -0
- autoregressive_flow.py +259 -0
- common.py +1083 -0
- configs/radtts-pp-dap-model.json +218 -0
- data.py +606 -0
- distributed.py +161 -0
- filelists/3speakers_ukrainian_train_filelist.txt +0 -0
- filelists/3speakers_ukrainian_train_filelist_dc.txt +0 -0
- filelists/3speakers_ukrainian_val_filelist.txt +85 -0
- filelists/3speakers_ukrainian_val_filelist_dc.txt +85 -0
- loss.py +228 -0
- partialconv1d.py +77 -0
- radam.py +114 -0
- radtts.py +936 -0
- requirements-dev.txt +1 -0
- requirements.txt +15 -0
- splines.py +326 -0
- transformer.py +219 -0
- tts_text_processing/LICENSE +19 -0
- tts_text_processing/abbreviations.py +57 -0
- tts_text_processing/acronyms.py +69 -0
- tts_text_processing/cleaners.py +123 -0
- tts_text_processing/cmudict.py +140 -0
- tts_text_processing/datestime.py +24 -0
- tts_text_processing/grapheme_dictionary.py +37 -0
- tts_text_processing/heteronyms +413 -0
- tts_text_processing/letters_and_numbers.py +96 -0
- tts_text_processing/numerical.py +175 -0
- tts_text_processing/symbols.py +144 -0
- tts_text_processing/text_processing.py +201 -0
.dockerignore
ADDED
@@ -0,0 +1,3 @@
+.ruff_cache/
+.venv/
+models/
.gitignore
ADDED
@@ -0,0 +1,9 @@
+.idea/
+.venv/
+.ruff_cache/
+__pycache__/
+
+flagged/
+models/
+
+audio.wav
RADTTS-LICENSE
ADDED
@@ -0,0 +1,19 @@
+Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,29 @@
 ---
-
-
-colorFrom: indigo
-colorTo: indigo
+license: apache-2.0
+title: RAD-TTS++ Ukrainian (Vocos)
 sdk: gradio
+emoji: 🎧
+colorFrom: blue
+colorTo: gray
+short_description: Use RAD-TTS++ model to synthesize text in Ukrainian
 sdk_version: 5.19.0
-app_file: app.py
-pinned: false
 ---
 
-
+## Install
+
+```shell
+uv venv --python 3.10
+
+source .venv/bin/activate
+
+uv pip install -r requirements.txt
+
+# in development mode
+uv pip install -r requirements-dev.txt
+```
+
+## Run
+
+```shell
+python app.py
+```
alignment.py
ADDED
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import numpy as np
+from numba import jit
+
+
+@jit(nopython=True)
+def mas_width1(attn_map):
+    """mas with hardcoded width=1"""
+    # assumes mel x text
+    opt = np.zeros_like(attn_map)
+    attn_map = np.log(attn_map)
+    attn_map[0, 1:] = -np.inf
+    log_p = np.zeros_like(attn_map)
+    log_p[0, :] = attn_map[0, :]
+    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
+    for i in range(1, attn_map.shape[0]):
+        for j in range(attn_map.shape[1]):  # for each text dim
+            prev_log = log_p[i - 1, j]
+            prev_j = j
+
+            if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
+                prev_log = log_p[i - 1, j - 1]
+                prev_j = j - 1
+
+            log_p[i, j] = attn_map[i, j] + prev_log
+            prev_ind[i, j] = prev_j
+
+    # now backtrack
+    curr_text_idx = attn_map.shape[1] - 1
+    for i in range(attn_map.shape[0] - 1, -1, -1):
+        opt[i, curr_text_idx] = 1
+        curr_text_idx = prev_ind[i, curr_text_idx]
+    opt[0, curr_text_idx] = 1
+    return opt
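A quick sanity-check sketch for the file above (not part of the commit; the random attention map is an assumption for illustration): `mas_width1` takes a mel-by-text soft attention map and returns a hard monotonic alignment of the same shape, with exactly one active text token per mel frame.

```python
# Toy illustration of mas_width1 on a row-normalized random attention map.
import numpy as np

from alignment import mas_width1

rng = np.random.default_rng(seed=0)
attn_soft = rng.random((12, 5))  # 12 mel frames x 5 text tokens
attn_soft /= attn_soft.sum(axis=1, keepdims=True)  # row-normalize like a softmax

attn_hard = mas_width1(attn_soft)

assert attn_hard.shape == attn_soft.shape
assert (attn_hard.sum(axis=1) == 1).all()  # exactly one token per mel frame
assert (np.diff(attn_hard.argmax(axis=1)) >= 0).all()  # path never moves backwards
```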
app.py
ADDED
@@ -0,0 +1,356 @@
+import os
+import sys
+import json
+import time
+
+from importlib.metadata import version
+from enum import Enum
+
+from huggingface_hub import hf_hub_download
+
+use_zerogpu = False
+
+try:
+    import spaces  # it's for ZeroGPU
+    use_zerogpu = True
+    print("ZeroGPU is available, changing inference call.")
+except ImportError:
+    print("ZeroGPU is not available, skipping...")
+
+import gradio as gr
+
+import torch
+import torchaudio
+
+# Vocos
+from vocos import Vocos
+
+# RAD-TTS code
+from radtts import RADTTS
+from data import Data
+from common import update_params
+
+use_cuda = torch.cuda.is_available()
+
+if use_cuda:
+    print("CUDA is available, setting correct inference_device variable.")
+    device = "cuda"
+else:
+    device = "cpu"
+
+
+def download_file_from_repo(
+    repo_id: str,
+    filename: str,
+    local_dir: str = ".",
+    repo_type: str = "model",
+) -> str:
+    try:
+        os.makedirs(local_dir, exist_ok=True)
+
+        file_path = hf_hub_download(
+            repo_id=repo_id,
+            filename=filename,
+            local_dir=local_dir,
+            cache_dir=None,
+            force_download=False,
+            repo_type=repo_type,
+        )
+
+        return file_path
+    except Exception as e:
+        raise Exception(f"An error occurred during download: {e}") from e
+
+
+download_file_from_repo(
+    "Yehor/radtts-uk",
+    "radtts-pp-dap-model/model_dap_84000.pt",
+    "./models/",
+)
+
+# Init the model
+seed = 1234
+
+config = "configs/radtts-pp-dap-model.json"
+radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
+
+params = []
+
+# Load the config
+with open(config) as f:
+    data = f.read()
+
+config = json.loads(data)
+update_params(config, params)
+
+data_config = config["data_config"]
+model_config = config["model_config"]
+
+# Seed
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+
+# Load vocoder
+vocos = Vocos.from_pretrained("patriotyk/vocos-mel-hifigan-compat-44100khz").to(device)
+
+# Load RAD-TTS
+if use_cuda:
+    radtts = RADTTS(**model_config).cuda()
+else:
+    radtts = RADTTS(**model_config)
+
+radtts.enable_inverse_cache()  # cache inverse matrix for 1x1 invertible convs
+
+checkpoint_dict = torch.load(radtts_path, map_location="cpu")  # todo: CPU?
+radtts.load_state_dict(checkpoint_dict["state_dict"], strict=False)
+radtts.eval()
+
+print(f"Loaded checkpoint '{radtts_path}')")
+
+ignore_keys = ["training_files", "validation_files"]
+trainset = Data(
+    data_config["training_files"],
+    **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
+)
+
+# Config
+concurrency_limit = 5
+
+title = "RAD-TTS++ Ukrainian"
+
+# https://www.tablesgenerator.com/markdown_tables
+authors_table = """
+## Authors
+
+Follow them on social networks and **contact** if you need any help or have any questions:
+
+| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
+|-------------------------------------------------------------------------------------------------|
+| https://t.me/smlkw in Telegram |
+| https://x.com/yehor_smoliakov at X |
+| https://github.com/egorsmkv at GitHub |
+| https://huggingface.co/Yehor at Hugging Face |
+| or use [email protected] |
+""".strip()
+
+description_head = f"""
+# {title}
+
+## Overview
+
+Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and [Vocos](https://huggingface.co/patriotyk/vocos-mel-hifigan-compat-44100khz) with 44100 Hz.
+""".strip()
+
+description_foot = f"""
+{authors_table}
+""".strip()
+
+tech_env = f"""
+#### Environment
+
+- Python: {sys.version}
+""".strip()
+
+tech_libraries = f"""
+#### Libraries
+
+- gradio: {version("gradio")}
+- torch: {version("torch")}
+- scipy: {version("scipy")}
+- numba: {version("numba")}
+- librosa: {version("librosa")}
+- unidecode: {version("unidecode")}
+- inflect: {version("inflect")}
+""".strip()
+
+
+class VoiceOption(Enum):
+    Tetiana = "Tetiana (female) 👩"
+    Mykyta = "Mykyta (male) 👨"
+    Lada = "Lada (female) 👩"
+
+
+voice_mapping = {
+    VoiceOption.Tetiana.value: "tetiana",
+    VoiceOption.Mykyta.value: "mykyta",
+    VoiceOption.Lada.value: "lada",
+}
+
+
+examples = [
+    [
+        "Прокинувся ґазда вранці. Пішов, вичистив з-під коня, вичистив з-під бика, вичистив з-під овечок, вибрав молодняк, відніс його набік.",
+        VoiceOption.Mykyta.value,
+    ],
+    [
+        "Пішов взяв сіна, дав корові. Пішов взяв сіна, дав бикові. Ячміню коняці насипав. Зайшов почистив корову, зайшов почистив бика, зайшов почистив коня, за яйця його мацнув.",
+        VoiceOption.Lada.value,
+    ],
+    [
+        "Кінь ногою здригнув, на хазяїна ласкавим оком подивився. Тоді дядько пішов відкрив курей, гусей, качок, повиносив їм зерна, огірків нарізаних, нагодував. Коли чує – з хати дружина кличе. Зайшов. Дітки повмивані, сидять за столом, всі чекають тата. Взяв він ложку, перехрестив дітей, перехрестив лоба, почали снідати. Поснідали, він дістав пряників, роздав дітям. Діти зібралися, пішли в школу. Дядько вийшов, сів на призьбі, взяв сапку, почав мантачити. Мантачив-мантачив, коли – жінка виходить. Він їй ту сапку дає, ласкаво за сраку вщипнув, жінка до нього лагідно всміхнулася, пішла на город – сапати. Коли – йде пастух і товар кличе в череду. Повідмикав дядько овечок, коровку, бика, коня, все відпустив. Сів попри хати, дістав табАку, відірвав шмат газети, насипав, наслинив собі гарну таку цигарку. Благодать божа – і сонечко вже здійнялося над деревами. Дядько встромив цигарку в рота, дістав сірники, тільки чиркати – коли раптом з хати: Доброе утро! Московское время – шесть часов утра! Витяг дядько цигарку с рота, сплюнув набік, і сам собі каже: Ана маєш. Прокинулись, бляді!",
+        VoiceOption.Tetiana.value,
+    ],
+]
+
+
+def inference(text, voice):
+    if not text:
+        raise gr.Error("Please paste your text.")
+
+    gr.Info("Starting...", duration=0.5)
+
+    speaker = voice_mapping[voice]
+    speaker = speaker_text = speaker_attributes = speaker
+
+    n_takes = 1
+
+    sigma = 0.8  # sampling sigma for decoder
+    sigma_tkndur = 0.666  # sampling sigma for duration
+    sigma_f0 = 1.0  # sampling sigma for f0
+    sigma_energy = 1.0  # sampling sigma for energy avg
+
+    token_dur_scaling = 1.0
+
+    f0_mean = 0
+    f0_std = 0
+    energy_mean = 0
+    energy_std = 0
+
+    if use_cuda:
+        speaker_id = trainset.get_speaker_id(speaker).cuda()
+        speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
+
+        if speaker_text is not None:
+            speaker_id_text = trainset.get_speaker_id(speaker_text).cuda()
+
+        if speaker_attributes is not None:
+            speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).cuda()
+
+        tensor_text = trainset.get_text(text).cuda()[None]
+    else:
+        speaker_id = trainset.get_speaker_id(speaker)
+        speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
+
+        if speaker_text is not None:
+            speaker_id_text = trainset.get_speaker_id(speaker_text)
+
+        if speaker_attributes is not None:
+            speaker_id_attributes = trainset.get_speaker_id(speaker_attributes)
+
+        tensor_text = trainset.get_text(text)[None]
+
+    inference_start = time.time()
+
+    for take in range(n_takes):
+        with torch.autocast(device, enabled=False):
+            with torch.inference_mode():
+                outputs = radtts.infer(
+                    speaker_id,
+                    tensor_text,
+                    sigma,
+                    sigma_tkndur,
+                    sigma_f0,
+                    sigma_energy,
+                    token_dur_scaling,
+                    token_duration_max=100,
+                    speaker_id_text=speaker_id_text,
+                    speaker_id_attributes=speaker_id_attributes,
+                    f0_mean=f0_mean,
+                    f0_std=f0_std,
+                    energy_mean=energy_mean,
+                    energy_std=energy_std,
+                    use_cuda=use_cuda,
+                )
+
+                mel = outputs["mel"]
+
+                gr.Info(
+                    "Synthesized MEL spectrogram, converting to WAVE.", duration=0.5
+                )
+
+                wav_gen = vocos.decode(mel)
+                wav_gen_float = wav_gen.cpu()
+
+                torchaudio.save("audio.wav", wav_gen_float, 44_100, encoding="PCM_S")
+
+    duration = len(wav_gen_float[0]) / 44_100
+
+    elapsed_time = time.time() - inference_start
+    rtf = elapsed_time / duration
+
+    speed_ratio = duration / elapsed_time
+    speech_rate = len(text.split(" ")) / duration
+
+    rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words-per-second."
+
+    gr.Success("Finished!", duration=0.5)
+
+    return [gr.Audio("audio.wav"), rtf_value]
+
+
+try:
+    @spaces.GPU
+    def inference_zerogpu(text, voice):
+        return inference(text, voice)
+except NameError:
+    print("ZeroGPU is not available, skipping...")
+
+
+def inference_cpu(text, voice):
+    return inference(text, voice)
+
+
+demo = gr.Blocks(
+    title=title,
+    analytics_enabled=False,
+    theme=gr.themes.Base(),
+)
+
+with demo:
+    gr.Markdown(description_head)
+
+    gr.Markdown("## Usage")
+
+    with gr.Row():
+        with gr.Column():
+            audio = gr.Audio(label="Synthesized audio")
+            rtf = gr.Markdown(
+                label="Real-Time Factor",
+                value="Here you will see how fast the model and the speaker is.",
+            )
+
+    with gr.Row():
+        with gr.Column():
+            text = gr.Text(
+                label="Text",
+                value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
+            )
+            voice = gr.Radio(
+                label="Voice",
+                choices=[option.value for option in VoiceOption],
+                value=VoiceOption.Tetiana.value,
+            )
+
+    gr.Button("Run").click(
+        inference_zerogpu if use_zerogpu else inference_cpu,
+        concurrency_limit=concurrency_limit,
+        inputs=[text, voice],
+        outputs=[audio, rtf],
+    )
+
+    with gr.Row():
+        gr.Examples(
+            label="Choose an example",
+            inputs=[text, voice],
+            examples=examples,
+        )
+
+    gr.Markdown(description_foot)
+
+    gr.Markdown("### Gradio app uses:")
+    gr.Markdown(tech_env)
+    gr.Markdown(tech_libraries)
+
+if __name__ == "__main__":
+    demo.queue()
+    demo.launch()
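For completeness, a hedged sketch of driving the deployed app remotely with `gradio_client`; the Space id and `api_name` below are assumptions, not values from the commit — check the Space's "Use via API" page for the real ones.

```python
# Hypothetical remote call; the Space id and api_name are assumptions.
from gradio_client import Client

client = Client("Yehor/radtts-uk")  # assumed Space id
audio_path, rtf_report = client.predict(
    "Сл+ава Укра+їні! — українське вітання, національне гасло.",  # text
    "Tetiana (female) 👩",  # voice; must match a VoiceOption value
    api_name="/inference_cpu",  # assumed endpoint name
)
print(audio_path, rtf_report)
```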
attribute_prediction_model.py
ADDED
@@ -0,0 +1,402 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+import torch
+from torch import nn
+from common import ConvNorm, Invertible1x1Conv
+from common import AffineTransformationLayer, SplineTransformationLayer
+from common import ConvLSTMLinear
+from transformer import FFTransformer
+from autoregressive_flow import AR_Step, AR_Back_Step
+
+
+def get_attribute_prediction_model(config):
+    name = config["name"]
+    hparams = config["hparams"]
+    if name == "dap":
+        model = DAP(**hparams)
+    elif name == "bgap":
+        model = BGAP(**hparams)
+    elif name == "agap":
+        model = AGAP(**hparams)
+    else:
+        raise Exception("{} model is not supported".format(name))
+
+    return model
+
+
+class AttributeProcessing:
+    def __init__(self, take_log_of_input=False):
+        super(AttributeProcessing).__init__()
+        self.take_log_of_input = take_log_of_input
+
+    def normalize(self, x):
+        if self.take_log_of_input:
+            x = torch.log(x + 1)
+        return x
+
+    def denormalize(self, x):
+        if self.take_log_of_input:
+            x = torch.exp(x) - 1
+        return x
+
+
+class BottleneckLayerLayer(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        reduction_factor,
+        norm="weightnorm",
+        non_linearity="relu",
+        kernel_size=3,
+        use_partial_padding=False,
+    ):
+        super(BottleneckLayerLayer, self).__init__()
+
+        self.reduction_factor = reduction_factor
+        reduced_dim = int(in_dim / reduction_factor)
+        self.out_dim = reduced_dim
+        if self.reduction_factor > 1:
+            fn = ConvNorm(
+                in_dim,
+                reduced_dim,
+                kernel_size=kernel_size,
+                use_weight_norm=(norm == "weightnorm"),
+            )
+            if norm == "instancenorm":
+                fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True))
+
+            self.projection_fn = fn
+            self.non_linearity = nn.ReLU()
+            if non_linearity == "leakyrelu":
+                self.non_linearity = nn.LeakyReLU()
+
+    def forward(self, x):
+        if self.reduction_factor > 1:
+            x = self.projection_fn(x)
+            x = self.non_linearity(x)
+        return x
+
+
+class DAP(nn.Module):
+    def __init__(
+        self,
+        n_speaker_dim,
+        bottleneck_hparams,
+        take_log_of_input,
+        arch_hparams,
+        use_transformer=False,
+    ):
+        super(DAP, self).__init__()
+        self.attribute_processing = AttributeProcessing(take_log_of_input)
+        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
+
+        arch_hparams["in_dim"] = self.bottleneck_layer.out_dim + n_speaker_dim
+        if use_transformer:
+            self.feat_pred_fn = FFTransformer(**arch_hparams)
+        else:
+            self.feat_pred_fn = ConvLSTMLinear(**arch_hparams)
+
+    def forward(self, txt_enc, spk_emb, x, lens):
+        if x is not None:
+            x = self.attribute_processing.normalize(x)
+
+        txt_enc = self.bottleneck_layer(txt_enc)
+        spk_emb_expanded = spk_emb[..., None].expand(-1, -1, txt_enc.shape[2])
+        context = torch.cat((txt_enc, spk_emb_expanded), 1)
+
+        x_hat = self.feat_pred_fn(context, lens)
+
+        outputs = {"x_hat": x_hat, "x": x}
+        return outputs
+
+    def infer(self, z, txt_enc, spk_emb, lens=None):
+        x_hat = self.forward(txt_enc, spk_emb, x=None, lens=lens)["x_hat"]
+        x_hat = self.attribute_processing.denormalize(x_hat)
+        return x_hat
+
+
+class BGAP(torch.nn.Module):
+    def __init__(
+        self,
+        n_in_dim,
+        n_speaker_dim,
+        bottleneck_hparams,
+        n_flows,
+        n_group_size,
+        n_layers,
+        with_dilation,
+        kernel_size,
+        scaling_fn,
+        take_log_of_input=False,
+        n_channels=1024,
+        use_quadratic=False,
+        n_bins=8,
+        n_spline_steps=2,
+    ):
+        super(BGAP, self).__init__()
+        # assert(n_group_size % 2 == 0)
+        self.n_flows = n_flows
+        self.n_group_size = n_group_size
+        self.transforms = torch.nn.ModuleList()
+        self.convinv = torch.nn.ModuleList()
+        self.n_speaker_dim = n_speaker_dim
+        self.scaling_fn = scaling_fn
+        self.attribute_processing = AttributeProcessing(take_log_of_input)
+        self.n_spline_steps = n_spline_steps
+        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
+        n_txt_reduced_dim = self.bottleneck_layer.out_dim
+        context_dim = n_txt_reduced_dim * n_group_size + n_speaker_dim
+
+        if self.n_group_size > 1:
+            self.unfold_params = {
+                "kernel_size": (n_group_size, 1),
+                "stride": n_group_size,
+                "padding": 0,
+                "dilation": 1,
+            }
+            self.unfold = nn.Unfold(**self.unfold_params)
+
+        for k in range(n_flows):
+            self.convinv.append(Invertible1x1Conv(n_in_dim * n_group_size))
+            if k >= n_flows - self.n_spline_steps:
+                left = -3
+                right = 3
+                top = 3
+                bottom = -3
+                self.transforms.append(
+                    SplineTransformationLayer(
+                        n_in_dim * n_group_size,
+                        context_dim,
+                        n_layers,
+                        with_dilation=with_dilation,
+                        kernel_size=kernel_size,
+                        scaling_fn=scaling_fn,
+                        n_channels=n_channels,
+                        top=top,
+                        bottom=bottom,
+                        left=left,
+                        right=right,
+                        use_quadratic=use_quadratic,
+                        n_bins=n_bins,
+                    )
+                )
+            else:
+                self.transforms.append(
+                    AffineTransformationLayer(
+                        n_in_dim * n_group_size,
+                        context_dim,
+                        n_layers,
+                        with_dilation=with_dilation,
+                        kernel_size=kernel_size,
+                        scaling_fn=scaling_fn,
+                        affine_model="simple_conv",
+                        n_channels=n_channels,
+                    )
+                )
+
+    def fold(self, data):
+        """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
+        the grouping or "squeeze" operation on input
+
+        Args:
+            data: B x C x T tensor of temporal data
+        """
+        output_size = (data.shape[2] * self.n_group_size, 1)
+        data = nn.functional.fold(
+            data, output_size=output_size, **self.unfold_params
+        ).squeeze(-1)
+        return data
+
+    def preprocess_context(self, txt_emb, speaker_vecs, std_scale=None):
+        if self.n_group_size > 1:
+            txt_emb = self.unfold(txt_emb[..., None])
+        speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
+        context = torch.cat((txt_emb, speaker_vecs), 1)
+        return context
+
+    def forward(self, txt_enc, spk_emb, x, lens):
+        """x<tensor>: duration or pitch or energy average"""
+        assert txt_enc.size(2) >= x.size(1)
+        if len(x.shape) == 2:
+            # add channel dimension
+            x = x[:, None]
+        txt_enc = self.bottleneck_layer(txt_enc)
+
+        # lens including padded values
+        lens_grouped = (lens // self.n_group_size).long()
+        context = self.preprocess_context(txt_enc, spk_emb)
+        x = self.unfold(x[..., None])
+        log_s_list, log_det_W_list = [], []
+        for k in range(self.n_flows):
+            x, log_s = self.transforms[k](x, context, seq_lens=lens_grouped)
+            x, log_det_W = self.convinv[k](x)
+            log_det_W_list.append(log_det_W)
+            log_s_list.append(log_s)
+        # prepare outputs
+        outputs = {"z": x, "log_det_W_list": log_det_W_list, "log_s_list": log_s_list}
+
+        return outputs
+
+    def infer(self, z, txt_enc, spk_emb, seq_lens):
+        txt_enc = self.bottleneck_layer(txt_enc)
+        context = self.preprocess_context(txt_enc, spk_emb)
+        lens_grouped = (seq_lens // self.n_group_size).long()
+        z = self.unfold(z[..., None])
+        for k in reversed(range(self.n_flows)):
+            z = self.convinv[k](z, inverse=True)
+            z = self.transforms[k].forward(
+                z, context, inverse=True, seq_lens=lens_grouped
+            )
+        # z mapped to input domain
+        x_hat = self.fold(z)
+        # pad on the way out
+        return x_hat
+
+
+class AGAP(torch.nn.Module):
+    def __init__(
+        self,
+        n_in_dim,
+        n_speaker_dim,
+        n_flows,
+        n_hidden,
+        n_lstm_layers,
+        bottleneck_hparams,
+        scaling_fn="exp",
+        take_log_of_input=False,
+        p_dropout=0.0,
+        setup="",
+        spline_flow_params=None,
+        n_group_size=1,
+    ):
+        super(AGAP, self).__init__()
+        self.flows = torch.nn.ModuleList()
+        self.n_group_size = n_group_size
+        self.n_speaker_dim = n_speaker_dim
+        self.attribute_processing = AttributeProcessing(take_log_of_input)
+        self.n_in_dim = n_in_dim
+        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
+        n_txt_reduced_dim = self.bottleneck_layer.out_dim
+
+        if self.n_group_size > 1:
+            self.unfold_params = {
+                "kernel_size": (n_group_size, 1),
+                "stride": n_group_size,
+                "padding": 0,
+                "dilation": 1,
+            }
+            self.unfold = nn.Unfold(**self.unfold_params)
+
+        if spline_flow_params is not None:
+            spline_flow_params["n_in_channels"] *= self.n_group_size
+
+        for i in range(n_flows):
+            if i % 2 == 0:
+                self.flows.append(
+                    AR_Step(
+                        n_in_dim * n_group_size,
+                        n_speaker_dim,
+                        n_txt_reduced_dim * n_group_size,
+                        n_hidden,
+                        n_lstm_layers,
+                        scaling_fn,
+                        spline_flow_params,
+                    )
+                )
+            else:
+                self.flows.append(
+                    AR_Back_Step(
+                        n_in_dim * n_group_size,
+                        n_speaker_dim,
+                        n_txt_reduced_dim * n_group_size,
+                        n_hidden,
+                        n_lstm_layers,
+                        scaling_fn,
+                        spline_flow_params,
+                    )
+                )
+
+    def fold(self, data):
+        """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
+        the grouping or "squeeze" operation on input
+
+        Args:
+            data: B x C x T tensor of temporal data
+        """
+        output_size = (data.shape[2] * self.n_group_size, 1)
+        data = nn.functional.fold(
+            data, output_size=output_size, **self.unfold_params
+        ).squeeze(-1)
+        return data
+
+    def preprocess_context(self, txt_emb, speaker_vecs):
+        if self.n_group_size > 1:
+            txt_emb = self.unfold(txt_emb[..., None])
+        speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
+        context = torch.cat((txt_emb, speaker_vecs), 1)
+        return context
+
+    def forward(self, txt_emb, spk_emb, x, lens):
+        """x<tensor>: duration or pitch or energy average"""
+
+        x = x[:, None] if len(x.shape) == 2 else x  # add channel dimension
+        if self.n_group_size > 1:
+            x = self.unfold(x[..., None])
+        x = x.permute(2, 0, 1)  # permute to time, batch, dims
+        x = self.attribute_processing.normalize(x)
+
+        txt_emb = self.bottleneck_layer(txt_emb)
+        context = self.preprocess_context(txt_emb, spk_emb)
+        context = context.permute(2, 0, 1)  # permute to time, batch, dims
+
+        lens_groupped = (lens / self.n_group_size).long()
+        log_s_list = []
+        for i, flow in enumerate(self.flows):
+            x, log_s = flow(x, context, lens_groupped)
+            log_s_list.append(log_s)
+
+        x = x.permute(1, 2, 0)  # x mapped to z
+        log_s_list = [log_s_elt.permute(1, 2, 0) for log_s_elt in log_s_list]
+        outputs = {"z": x, "log_s_list": log_s_list, "log_det_W_list": []}
+        return outputs
+
+    def infer(self, z, txt_emb, spk_emb, seq_lens=None):
+        if self.n_group_size > 1:
+            n_frames = z.shape[2]
+            z = self.unfold(z[..., None])
+        z = z.permute(2, 0, 1)  # permute to time, batch, dims
+
+        txt_emb = self.bottleneck_layer(txt_emb)
+        context = self.preprocess_context(txt_emb, spk_emb)
+        context = context.permute(2, 0, 1)  # permute to time, batch, dims
+
+        for i, flow in enumerate(reversed(self.flows)):
+            z = flow.infer(z, context)
+
+        x_hat = z.permute(1, 2, 0)
+        if self.n_group_size > 1:
+            x_hat = self.fold(x_hat)
+            if n_frames > x_hat.shape[2]:
+                m = nn.ReflectionPad1d((0, n_frames - x_hat.shape[2]))
+                x_hat = m(x_hat)
+
+        x_hat = self.attribute_processing.denormalize(x_hat)
+        return x_hat
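The BGAP and AGAP flows above rely on the same `unfold`/`fold` pair to group time steps into the channel axis (the "squeeze" operation) and back. A torch-only sketch with assumed shapes shows the two are exact inverses when the sequence length is divisible by the group size:

```python
# Sketch of the grouping trick; shapes are illustrative assumptions.
import torch
from torch import nn

n_group_size = 2
unfold_params = {
    "kernel_size": (n_group_size, 1),
    "stride": n_group_size,
    "padding": 0,
    "dilation": 1,
}
unfold = nn.Unfold(**unfold_params)

x = torch.randn(4, 8, 10)       # B x C x T, T divisible by n_group_size
grouped = unfold(x[..., None])  # B x (C * n_group_size) x (T // n_group_size)
assert grouped.shape == (4, 16, 5)

restored = nn.functional.fold(
    grouped, output_size=(10, 1), **unfold_params
).squeeze(-1)
assert torch.equal(restored, x)  # non-overlapping fold inverts unfold exactly
```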
audio_processing.py
ADDED
@@ -0,0 +1,328 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+import torch
+import numpy as np
+from scipy.signal import get_window
+from librosa.filters import mel as librosa_mel_fn
+import librosa.util as librosa_util
+
+
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length=200,
+    win_length=800,
+    n_fft=800,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+    return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
+
+
+class TacotronSTFT(torch.nn.Module):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=80,
+        sampling_rate=22050,
+        mel_fmin=0.0,
+        mel_fmax=None,
+    ):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sr=sampling_rate,
+            n_fft=filter_length,
+            n_mels=n_mel_channels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+        )
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+
+    def spectral_normalize(self, magnitudes):
+        output = dynamic_range_compression(magnitudes)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert torch.min(y.data) >= -1
+        assert torch.max(y.data) <= 1
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output)
+        return mel_output
+
+
+"""
+BSD 3-Clause License
+
+Copyright (c) 2017, Prem Seetharaman
+All rights reserved.
+
+* Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+import torch.nn.functional as F
+from torch.autograd import Variable
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+    def __init__(
+        self, filter_length=800, hop_length=200, win_length=800, window="hann"
+    ):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        if window is not None:
+            assert win_length >= filter_length
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, size=filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer("forward_basis", forward_basis.float())
+        self.register_buffer("inverse_basis", inverse_basis.float())
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode="reflect",
+        )
+        input_data = input_data.squeeze(1)
+
+        forward_transform = F.conv1d(
+            input_data,
+            Variable(self.forward_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window,
+                magnitude.size(-1),
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                n_fft=self.filter_length,
+                dtype=np.float32,
+            )
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0]
+            )
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False
+            )
+            window_sum = window_sum.to(magnitude.device)
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+                approx_nonzero_indices
+            ]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
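A minimal usage sketch for the module above (parameter values are illustrative assumptions, not the Space's shipped config): `mel_spectrogram` expects a batch of waveforms in [-1, 1] and returns a log-compressed mel spectrogram.

```python
# Compute a mel spectrogram from one second of synthetic audio; all
# hyperparameters here are assumed defaults for illustration only.
import torch
from audio_processing import TacotronSTFT

stft = TacotronSTFT(
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=80,
    sampling_rate=22050,
    mel_fmin=0.0,
    mel_fmax=8000.0,
)

audio = torch.rand(1, 22050) * 2 - 1  # B x T waveform in [-1, 1]
mel = stft.mel_spectrogram(audio)     # B x n_mel_channels x n_frames
print(mel.shape)                      # roughly torch.Size([1, 80, 87])
```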
autoregressive_flow.py
ADDED
@@ -0,0 +1,259 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+# AR_Back_Step and AR_Step based on implementation from
+# https://github.com/NVIDIA/flowtron/blob/master/flowtron.py
+# Original license text:
+###############################################################################
+#
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+###############################################################################
+# Original Author and Contact: Rafael Valle
+# Modification by Rafael Valle
+
+import torch
+from torch import nn
+from common import DenseLayer, SplineTransformationLayerAR
+
+
+class AR_Back_Step(torch.nn.Module):
+    def __init__(
+        self,
+        n_attr_channels,
+        n_speaker_dim,
+        n_text_dim,
+        n_hidden,
+        n_lstm_layers,
+        scaling_fn,
+        spline_flow_params=None,
+    ):
+        super(AR_Back_Step, self).__init__()
+        self.ar_step = AR_Step(
+            n_attr_channels,
+            n_speaker_dim,
+            n_text_dim,
+            n_hidden,
+            n_lstm_layers,
+            scaling_fn,
+            spline_flow_params,
+        )
+
+    def forward(self, mel, context, lens):
+        mel = torch.flip(mel, (0,))
+        context = torch.flip(context, (0,))
+        # backwards flow, send padded zeros back to end
+        for k in range(mel.size(1)):
+            mel[:, k] = mel[:, k].roll(lens[k].item(), dims=0)
+            context[:, k] = context[:, k].roll(lens[k].item(), dims=0)
+
+        mel, log_s = self.ar_step(mel, context, lens)
+
+        # move padded zeros back to beginning
+        for k in range(mel.size(1)):
+            mel[:, k] = mel[:, k].roll(-lens[k].item(), dims=0)
+
+        return torch.flip(mel, (0,)), log_s
+
+    def infer(self, residual, context):
+        residual = self.ar_step.infer(
+            torch.flip(residual, (0,)), torch.flip(context, (0,))
+        )
+        residual = torch.flip(residual, (0,))
+        return residual
+
+
+class AR_Step(torch.nn.Module):
+    def __init__(
+        self,
+        n_attr_channels,
+        n_speaker_dim,
+        n_text_channels,
+        n_hidden,
+        n_lstm_layers,
+        scaling_fn,
+        spline_flow_params=None,
+    ):
+        super(AR_Step, self).__init__()
+        if spline_flow_params is not None:
+            self.spline_flow = SplineTransformationLayerAR(**spline_flow_params)
+        else:
+            self.n_out_dims = n_attr_channels
+            self.conv = torch.nn.Conv1d(n_hidden, 2 * n_attr_channels, 1)
+            self.conv.weight.data = 0.0 * self.conv.weight.data
+            self.conv.bias.data = 0.0 * self.conv.bias.data
+
+        self.attr_lstm = torch.nn.LSTM(n_attr_channels, n_hidden)
+        self.lstm = torch.nn.LSTM(
+            n_hidden + n_text_channels + n_speaker_dim, n_hidden, n_lstm_layers
+        )
+
+        if spline_flow_params is None:
+            self.dense_layer = DenseLayer(in_dim=n_hidden, sizes=[n_hidden, n_hidden])
+        self.scaling_fn = scaling_fn
+
+    def run_padded_sequence(
+        self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
+    ):
+        """Sorts input data by provided ordering (and un-ordering) and runs the
+        packed data through the recurrent model
+
+        Args:
+            sorted_idx (torch.tensor): 1D sorting index
+            unsort_idx (torch.tensor): 1D unsorting index (inverse sorted_idx)
+            lens: lengths of input data (sorted in descending order)
+            padded_data (torch.tensor): input sequences (padded)
+            recurrent_model (nn.Module): recurrent model to run data through
+        Returns:
+            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
+            unsorted, ordering
+        """
+
+        # sort the data by decreasing length using provided index
+        # we assume batch index is in dim=1
+        padded_data = padded_data[:, sorted_idx]
+        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens.cpu())
+        hidden_vectors = recurrent_model(padded_data)[0]
+        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
+        # unsort the results at dim=1 and return
+        hidden_vectors = hidden_vectors[:, unsort_idx]
+        return hidden_vectors
+
+    def get_scaling_and_logs(self, scale_unconstrained):
+        if self.scaling_fn == "translate":
+            s = torch.exp(scale_unconstrained * 0)
+            log_s = scale_unconstrained * 0
+        elif self.scaling_fn == "exp":
+            s = torch.exp(scale_unconstrained)
+            log_s = scale_unconstrained  # log(exp
+        elif self.scaling_fn == "tanh":
+            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
+            log_s = torch.log(s)
+        elif self.scaling_fn == "sigmoid":
+            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
+            log_s = torch.log(s)
+        else:
+            raise Exception("Scaling fn {} not supp.".format(self.scaling_fn))
+
+        return s, log_s
+
+    def forward(self, mel, context, lens):
+        dummy = torch.FloatTensor(1, mel.size(1), mel.size(2)).zero_()
+        dummy = dummy.type(mel.type())
+        # seq_len x batch x dim
+        mel0 = torch.cat([dummy, mel[:-1]], 0)
+
+        self.lstm.flatten_parameters()
+        self.attr_lstm.flatten_parameters()
+        if lens is not None:
+            # collect decreasing length indices
+            lens, ids = torch.sort(lens, descending=True)
+            original_ids = [0] * lens.size(0)
+            for i, ids_i in enumerate(ids):
+                original_ids[ids_i] = i
+            # mel_seq_len x batch x hidden_dim
+            mel_hidden = self.run_padded_sequence(
+                ids, original_ids, lens, mel0, self.attr_lstm
+            )
+        else:
+            mel_hidden = self.attr_lstm(mel0)[0]
+
+        decoder_input = torch.cat((mel_hidden, context), -1)
+
|
192 |
+
if lens is not None:
|
193 |
+
# reorder, run padded sequence and undo reordering
|
194 |
+
lstm_hidden = self.run_padded_sequence(
|
195 |
+
ids, original_ids, lens, decoder_input, self.lstm
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
lstm_hidden = self.lstm(decoder_input)[0]
|
199 |
+
|
200 |
+
if hasattr(self, "spline_flow"):
|
201 |
+
# spline flow fn expects inputs to be batch, channel, time
|
202 |
+
lstm_hidden = lstm_hidden.permute(1, 2, 0)
|
203 |
+
mel = mel.permute(1, 2, 0)
|
204 |
+
mel, log_s = self.spline_flow(mel, lstm_hidden, inverse=False)
|
205 |
+
mel = mel.permute(2, 0, 1)
|
206 |
+
log_s = log_s.permute(2, 0, 1)
|
207 |
+
else:
|
208 |
+
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
|
209 |
+
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
|
210 |
+
|
211 |
+
scale, log_s = self.get_scaling_and_logs(
|
212 |
+
decoder_output[:, :, : self.n_out_dims]
|
213 |
+
)
|
214 |
+
bias = decoder_output[:, :, self.n_out_dims :]
|
215 |
+
|
216 |
+
mel = scale * mel + bias
|
217 |
+
|
218 |
+
return mel, log_s
|
219 |
+
|
220 |
+
def infer(self, residual, context):
|
221 |
+
total_output = [] # seems 10FPS faster than pre-allocation
|
222 |
+
|
223 |
+
output = None
|
224 |
+
dummy = torch.cuda.FloatTensor(1, residual.size(1), residual.size(2)).zero_()
|
225 |
+
self.attr_lstm.flatten_parameters()
|
226 |
+
|
227 |
+
for i in range(0, residual.size(0)):
|
228 |
+
if i == 0:
|
229 |
+
output = dummy
|
230 |
+
mel_hidden, (h, c) = self.attr_lstm(output)
|
231 |
+
else:
|
232 |
+
mel_hidden, (h, c) = self.attr_lstm(output, (h, c))
|
233 |
+
|
234 |
+
decoder_input = torch.cat((mel_hidden, context[i][None]), -1)
|
235 |
+
|
236 |
+
if i == 0:
|
237 |
+
lstm_hidden, (h1, c1) = self.lstm(decoder_input)
|
238 |
+
else:
|
239 |
+
lstm_hidden, (h1, c1) = self.lstm(decoder_input, (h1, c1))
|
240 |
+
|
241 |
+
if hasattr(self, "spline_flow"):
|
242 |
+
# expects inputs to be batch, channel, time
|
243 |
+
lstm_hidden = lstm_hidden.permute(1, 2, 0)
|
244 |
+
output = residual[i : i + 1].permute(1, 2, 0)
|
245 |
+
output = self.spline_flow(output, lstm_hidden, inverse=True)
|
246 |
+
output = output.permute(2, 0, 1)
|
247 |
+
else:
|
248 |
+
lstm_hidden = self.dense_layer(lstm_hidden).permute(1, 2, 0)
|
249 |
+
decoder_output = self.conv(lstm_hidden).permute(2, 0, 1)
|
250 |
+
|
251 |
+
s, log_s = self.get_scaling_and_logs(
|
252 |
+
decoder_output[:, :, : decoder_output.size(2) // 2]
|
253 |
+
)
|
254 |
+
b = decoder_output[:, :, decoder_output.size(2) // 2 :]
|
255 |
+
output = (residual[i : i + 1] - b) / s
|
256 |
+
total_output.append(output)
|
257 |
+
|
258 |
+
total_output = torch.cat(total_output, 0)
|
259 |
+
return total_output
|
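The affine branch of AR_Step is invertible by construction: the forward (training) direction applies z' = s * z + b per frame, and inference solves back z = (z' - b) / s. A minimal standalone sketch of that identity, using plain tensors rather than the repo's modules:

import torch

torch.manual_seed(0)
z = torch.randn(4, 80)       # stand-in for mel frames
s = torch.rand(4, 80) + 0.5  # positive scales, as the "exp"/"tanh" scaling fns guarantee
b = torch.randn(4, 80)       # biases

z_fwd = s * z + b            # what AR_Step.forward does in the affine branch
z_inv = (z_fwd - b) / s      # what AR_Step.infer does in the affine branch
assert torch.allclose(z, z_inv, atol=1e-6)

# The flow's log-determinant contribution is sum(log s), which is why
# get_scaling_and_logs returns log_s alongside s.
print(torch.log(s).sum())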
common.py
ADDED
@@ -0,0 +1,1083 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# 1x1InvertibleConv and WN based on implementation from WaveGlow https://github.com/NVIDIA/waveglow/blob/master/glow.py
# Original license:
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
import ast

from splines import (
    piecewise_linear_transform,
    piecewise_linear_inverse_transform,
    unbounded_piecewise_quadratic_transform,
)
from partialconv1d import PartialConv1d as pconv1d
from typing import Tuple

use_cuda = torch.cuda.is_available()

if use_cuda:
    device = "cuda"
else:
    device = "cpu"


def update_params(config, params):
    for param in params:
        print(param)
        k, v = param.split("=")
        try:
            v = ast.literal_eval(v)
        except:
            pass

        k_split = k.split(".")
        if len(k_split) > 1:
            parent_k = k_split[0]
            cur_param = [".".join(k_split[1:]) + "=" + str(v)]
            update_params(config[parent_k], cur_param)
        elif k in config and len(k_split) == 1:
            print(f"overriding {k} with {v}")
            config[k] = v
        else:
            print("{}, {} params not updated".format(k, v))


def get_mask_from_lengths(lengths):
    """Constructs binary mask from a 1D torch tensor of input lengths

    Args:
        lengths (torch.tensor): 1D tensor
    Returns:
        mask (torch.tensor): num_sequences x max_length binary tensor
    """
    max_len = torch.max(lengths).item()
    if torch.cuda.is_available():
        ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    else:
        ids = torch.arange(0, max_len, out=torch.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask

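A minimal usage sketch for get_mask_from_lengths, assuming a CPU-only environment (on a CUDA machine the function builds its index on the GPU, so the lengths tensor would need to live there too):

import torch
from common import get_mask_from_lengths

mask = get_mask_from_lengths(torch.tensor([3, 1, 2]))
print(mask)
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])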
class ExponentialClass(torch.nn.Module):
    def __init__(self):
        super(ExponentialClass, self).__init__()

    def forward(self, x):
        return torch.exp(x)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_partial_padding=False,
        use_weight_norm=False,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_partial_padding = use_partial_padding
        self.use_weight_norm = use_weight_norm
        conv_fn = torch.nn.Conv1d
        if self.use_partial_padding:
            conv_fn = pconv1d
        self.conv = conv_fn(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )
        if self.use_weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, signal, mask=None):
        if self.use_partial_padding:
            conv_signal = self.conv(signal, mask)
        else:
            conv_signal = self.conv(signal)
        if mask is not None:
            # always re-zero output if mask is
            # available to match zero-padding
            conv_signal = conv_signal * mask
        return conv_signal


class DenseLayer(nn.Module):
    def __init__(self, in_dim=1024, sizes=[1024, 1024]):
        super(DenseLayer, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                LinearNorm(in_size, out_size, bias=True)
                for (in_size, out_size) in zip(in_sizes, sizes)
            ]
        )

    def forward(self, x):
        for linear in self.layers:
            x = torch.tanh(linear(x))
        return x


class LengthRegulator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dur):
        output = []
        for x_i, dur_i in zip(x, dur):
            expanded = self.expand(x_i, dur_i)
            output.append(expanded)
        output = self.pad(output)
        return output

    def expand(self, x, dur):
        output = []
        for i, frame in enumerate(x):
            expanded_len = int(dur[i] + 0.5)
            expanded = frame.expand(expanded_len, -1)
            output.append(expanded)
        output = torch.cat(output, 0)
        return output

    def pad(self, x):
        output = []
        max_len = max([x[i].size(0) for i in range(len(x))])
        for i, seq in enumerate(x):
            padded = F.pad(seq, [0, 0, 0, max_len - seq.size(0)], "constant", 0.0)
            output.append(padded)
        output = torch.stack(output)
        return output

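A minimal sketch of what LengthRegulator does, assuming common.py is importable: each input frame is repeated round(duration) times along the time axis, then sequences in the batch are padded to a common length.

import torch
from common import LengthRegulator

lr = LengthRegulator()
x = torch.arange(6, dtype=torch.float).reshape(1, 3, 2)  # B=1, T=3, C=2
dur = torch.tensor([[1.0, 2.0, 0.6]])                    # per-frame durations
out = lr(x, dur)
print(out.shape)  # torch.Size([1, 4, 2]): int(1.5) + int(2.5) + int(1.1) = 4 frames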
class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super(ConvLSTMLinear, self).__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)

        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = int(n_channels // 2)

            self.bilstm = nn.LSTM(
                n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
            )
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i in range(len(ids_sorted)):
            unsort_ids[ids_sorted[i]] = i
        lens_sorted = lens_sorted.long().cpu()

        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(
            context, lens_sorted, batch_first=True
        )
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]

        # map back to original indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to B, D, T
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat

    def infer(self, z, txt_enc, spk_emb):
        x_hat = self.forward(txt_enc, spk_emb)["x_hat"]
        x_hat = self.feature_processing.denormalize(x_hat)
        return x_hat


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """

    def __init__(
        self,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        encoder_kernel_size=5,
        norm_fn=nn.BatchNorm1d,
        lstm_norm_fn=None,
    ):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                    use_partial_padding=True,
                ),
                norm_fn(encoder_embedding_dim, affine=True),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        if lstm_norm_fn is not None:
            if "spectral" in lstm_norm_fn:
                print("Applying spectral norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
            elif "weight" in lstm_norm_fn:
                print("Applying weight norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.weight_norm
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")

    @torch.autocast(device, enabled=False)
    def forward(self, x, in_lens):
        """
        Args:
            x (torch.tensor): N x C x L padded input of text embeddings
            in_lens (torch.tensor): 1D tensor of sequence lengths
        """
        if x.size()[0] > 1:
            x_embedded = []
            for b_ind in range(x.size()[0]):  # TODO: improve speed
                curr_x = x[b_ind : b_ind + 1, :, : in_lens[b_ind]].clone()
                for conv in self.convolutions:
                    curr_x = F.dropout(F.relu(conv(curr_x)), 0.5, self.training)
                x_embedded.append(curr_x[0].transpose(0, 1))
            x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
        else:
            for conv in self.convolutions:
                x = F.dropout(F.relu(conv(x)), 0.5, self.training)
            x = x.transpose(1, 2)

        # recent amp change -- change in_lens to int
        in_lens = in_lens.int().cpu()

        x = nn.utils.rnn.pack_padded_sequence(x, in_lens, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs

    @torch.autocast(device, enabled=False)
    def infer(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs

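A minimal shape sketch for the text Encoder, assuming common.py is importable and lengths are already sorted in descending order (pack_padded_sequence is called with its default enforce_sorted=True):

import torch
from common import Encoder

enc = Encoder(encoder_n_convolutions=1, encoder_embedding_dim=8, encoder_kernel_size=3)
x = torch.randn(2, 8, 5)            # N x C x L padded text embeddings
out = enc(x, torch.tensor([5, 3]))  # lengths sorted descending
print(out.shape)                    # torch.Size([2, 5, 8])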
class Invertible1x1ConvLUS(torch.nn.Module):
    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        self.register_buffer("p", p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer("lower_diag", lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
        self.cache_inverse = cache_inverse

    @torch.autocast(device, enabled=False)
    def forward(self, z, inverse=False):
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if inverse:
            if not hasattr(self, "W_inverse"):
                # inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()

                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            W = W[..., None]
            z = F.conv1d(z, W, bias=None, stride=1, padding=0)
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If inverse=True it does convolution with
    inverse
    """

    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(
            c, c, kernel_size=1, stride=1, padding=0, bias=False
        )

        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W
        self.cache_inverse = cache_inverse

    def forward(self, z, inverse=False):
        # DO NOT apply n_of_groups, as it doesn't account for padded sequences
        W = self.conv.weight.squeeze()

        if inverse:
            if not hasattr(self, "W_inverse"):
                # Inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()

                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            # Forward computation
            log_det_W = torch.logdet(W).clone()
            z = self.conv(z)
            return z, log_det_W

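A minimal invertibility check for the LU-decomposed 1x1 convolution, assuming common.py is importable and a PyTorch version where the torch.lu call in __init__ is still available:

import torch
from common import Invertible1x1ConvLUS

conv = Invertible1x1ConvLUS(4)
z = torch.randn(2, 4, 7)        # batch x channels x time
y, log_det_W = conv(z)          # forward pass also yields log|det W|
z_rec = conv(y, inverse=True)   # applies W^-1 as a 1x1 convolution
print(torch.allclose(z, z_rec, atol=1e-5))  # True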
class SimpleConvNet(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        final_out_channels,
        n_layers=2,
        kernel_size=5,
        with_dilation=True,
        max_channels=1024,
        zero_init=True,
        use_partial_padding=True,
    ):
        super(SimpleConvNet, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.n_layers = n_layers
        in_channels = n_mel_channels + n_context_dim
        out_channels = -1
        self.use_partial_padding = use_partial_padding
        for i in range(n_layers):
            dilation = 2**i if with_dilation else 1
            padding = int((kernel_size * dilation - dilation) / 2)
            out_channels = min(max_channels, in_channels * 2)
            self.layers.append(
                ConvNorm(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=True,
                    w_init_gain="relu",
                    use_partial_padding=use_partial_padding,
                )
            )
            in_channels = out_channels

        self.last_layer = torch.nn.Conv1d(
            out_channels, final_out_channels, kernel_size=1
        )

        if zero_init:
            self.last_layer.weight.data *= 0
            self.last_layer.bias.data *= 0

    def forward(self, z_w_context, seq_lens: torch.Tensor = None):
        # seq_lens: tensor of sequence lengths
        # output should be b x n_mel_channels x z_w_context.shape(2)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()

        for i in range(self.n_layers):
            z_w_context = self.layers[i](z_w_context, mask)
            z_w_context = torch.relu(z_w_context)

        z_w_context = self.last_layer(z_w_context)
        return z_w_context


class WN(torch.nn.Module):
    """
    Adapted from WN() module in WaveGlow with modifications to variable names
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        n_channels,
        kernel_size=5,
        affine_activation="softplus",
        use_partial_padding=True,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert n_channels % 2 == 0
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name="weight")
        self.start = start
        self.softplus = torch.nn.Softplus()
        self.affine_activation = affine_activation
        self.use_partial_padding = use_partial_padding
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = ConvNorm(
                n_channels,
                n_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=padding,
                use_partial_padding=use_partial_padding,
                use_weight_norm=True,
            )
            # in_layer = nn.Conv1d(n_channels, n_channels, kernel_size,
            #                      dilation=dilation, padding=padding)
            # in_layer = nn.utils.weight_norm(in_layer)
            self.in_layers.append(in_layer)
            res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer)
            self.res_skip_layers.append(res_skip_layer)

    def forward(
        self,
        forward_input: Tuple[torch.Tensor, torch.Tensor],
        seq_lens: torch.Tensor = None,
    ):
        z, context = forward_input
        z = torch.cat((z, context), 1)  # append context to z as well
        z = self.start(z)
        output = torch.zeros_like(z)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
        non_linearity = torch.relu
        if self.affine_activation == "softplus":
            non_linearity = self.softplus

        for i in range(self.n_layers):
            z = non_linearity(self.in_layers[i](z, mask))
            res_skip_acts = non_linearity(self.res_skip_layers[i](z))
            output = output + res_skip_acts

        output = self.end(output)  # [B, dim, seq_len]
        return output

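A minimal shape sketch for WN, assuming common.py and partialconv1d.py are importable; because the end layer is zero-initialized, the predicted coupling parameters are exactly zero at initialization:

import torch
from common import WN

wn = WN(n_in_channels=4, n_context_dim=3, n_layers=2, n_channels=8, kernel_size=3)
z = torch.randn(2, 4, 6)
context = torch.randn(2, 3, 6)
out = wn((z, context))  # no seq_lens -> no masking
print(out.shape)        # torch.Size([2, 8, 6]), all zeros at init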
# Affine Coupling Layers
class SplineTransformationLayerAR(torch.nn.Module):
    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        kernel_size=1,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-6,
        right=6,
        bottom=-6,
        top=6,
        use_quadratic=False,
    ):
        super(SplineTransformationLayerAR, self).__init__()
        self.n_in_channels = n_in_channels  # input dimensions
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.n_in_channels * self.n_bins

        # autoregressive flow, kernel size 1 and no dilation
        self.param_predictor = SimpleConvNet(
            n_context_dim,
            0,
            final_out_channels,
            n_layers,
            with_dilation=False,
            kernel_size=1,
            zero_init=True,
            use_partial_padding=False,
        )

        # output is unnormalized bin weights

    def normalize(self, z, inverse):
        # normalize to [0, 1]
        if inverse:
            z = (z - self.bottom) / (self.top - self.bottom)
        else:
            z = (z - self.left) / (self.right - self.left)

        return z

    def denormalize(self, z, inverse):
        if inverse:
            z = z * (self.right - self.left) + self.left
        else:
            z = z * (self.top - self.bottom) + self.bottom

        return z

    def forward(self, z, context, inverse=False):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        z = self.normalize(z, inverse)

        if z.min() < 0.0 or z.max() > 1.0:
            print("spline z scaled beyond [0, 1]", z.min(), z.max())

        z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
        affine_params = self.param_predictor(context)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_tformed, log_s = self.spline_fn(
                    z_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
            else:
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())

        z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        z = self.denormalize(z, inverse)
        if inverse:
            return z

        log_s = log_s.reshape(b_s, t_s, -1)
        log_s = log_s.permute(0, 2, 1)
        log_s = log_s + c_s * (
            np.log(self.top - self.bottom) - np.log(self.right - self.left)
        )
        return z, log_s


class SplineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-4,
        right=4,
        bottom=-4,
        top=4,
        use_quadratic=False,
    ):
        super(SplineTransformationLayer, self).__init__()
        self.n_mel_channels = n_mel_channels  # input dimensions
        self.half_mel_channels = int(n_mel_channels / 2)  # half, because we split
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.half_mel_channels * self.n_bins

        self.param_predictor = SimpleConvNet(
            self.half_mel_channels,
            n_context_dim,
            final_out_channels,
            n_layers,
            with_dilation=with_dilation,
            kernel_size=kernel_size,
            zero_init=False,
        )

        # output is unnormalized bin weights

    def forward(self, z, context, inverse=False, seq_lens=None):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        # condition on z_0, transform z_1
        n_half = self.half_mel_channels
        z_0, z_1 = z[:, :n_half], z[:, n_half:]

        # normalize to [0,1]
        if inverse:
            z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
        else:
            z_1 = (z_1 - self.left) / (self.right - self.left)

        z_w_context = torch.cat((z_0, context), 1)
        affine_params = self.param_predictor(z_w_context, seq_lens)
        z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)

        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_1_tformed, log_s = self.spline_fn(
                    z_1_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
                if not inverse:
                    log_s = torch.sum(log_s, 1)
            else:
                if inverse:
                    z_1_tformed, _dc = self.inv_spline_fn(
                        z_1_reshaped.float(), q_tilde.float(), False
                    )
                else:
                    z_1_tformed, log_s = self.spline_fn(
                        z_1_reshaped.float(), q_tilde.float()
                    )

        z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)

        # undo [0, 1] normalization
        if inverse:
            z_1 = z_1 * (self.right - self.left) + self.left
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:  # training
            z_1 = z_1 * (self.top - self.bottom) + self.bottom
            z = torch.cat((z_0, z_1), dim=1)
            log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
                np.log(self.top - self.bottom) - np.log(self.right - self.left)
            )
            return z, log_s

class AffineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        use_partial_padding=False,
    ):
        super(AffineTransformationLayer, self).__init__()
        if affine_model not in ("wavenet", "simple_conv"):
            raise Exception("{} affine model not supported".format(affine_model))
        if isinstance(scaling_fn, list):
            if not all(
                [x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]
            ):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        else:
            if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
                raise Exception("{} scaling fn not supported".format(scaling_fn))

        self.affine_model = affine_model
        self.scaling_fn = scaling_fn
        if affine_model == "wavenet":
            self.affine_param_predictor = WN(
                int(n_mel_channels / 2),
                n_context_dim,
                n_layers=n_layers,
                n_channels=n_channels,
                affine_activation=affine_activation,
                use_partial_padding=use_partial_padding,
            )
        elif affine_model == "simple_conv":
            self.affine_param_predictor = SimpleConvNet(
                int(n_mel_channels / 2),
                n_context_dim,
                n_mel_channels,
                n_layers,
                with_dilation=with_dilation,
                kernel_size=kernel_size,
                use_partial_padding=use_partial_padding,
            )
        self.n_mel_channels = n_mel_channels

    def get_scaling_and_logs(self, scale_unconstrained):
        if self.scaling_fn == "translate":
            s = torch.exp(scale_unconstrained * 0)
            log_s = scale_unconstrained * 0
        elif self.scaling_fn == "exp":
            s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) = x
        elif self.scaling_fn == "tanh":
            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
            log_s = torch.log(s)
        elif self.scaling_fn == "sigmoid":
            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
            log_s = torch.log(s)
        elif isinstance(self.scaling_fn, list):
            s_list, log_s_list = [], []
            for i in range(scale_unconstrained.shape[1]):
                scaling_i = self.scaling_fn[i]
                if scaling_i == "translate":
                    s_i = torch.exp(scale_unconstrained[:, i] * 0)
                    log_s_i = scale_unconstrained[:, i] * 0
                elif scaling_i == "exp":
                    s_i = torch.exp(scale_unconstrained[:, i])
                    log_s_i = scale_unconstrained[:, i]
                elif scaling_i == "tanh":
                    s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
                    log_s_i = torch.log(s_i)
                elif scaling_i == "sigmoid":
                    s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
                    log_s_i = torch.log(s_i)
                s_list.append(s_i[:, None])
                log_s_list.append(log_s_i[:, None])
            s = torch.cat(s_list, dim=1)
            log_s = torch.cat(log_s_list, dim=1)
        return s, log_s

    def forward(self, z, context, inverse=False, seq_lens=None):
        n_half = int(self.n_mel_channels / 2)
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        if self.affine_model == "wavenet":
            affine_params = self.affine_param_predictor(
                (z_0, context), seq_lens=seq_lens
            )
        elif self.affine_model == "simple_conv":
            z_w_context = torch.cat((z_0, context), 1)
            affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)

        scale_unconstrained = affine_params[:, :n_half, :]
        b = affine_params[:, n_half:, :]
        s, log_s = self.get_scaling_and_logs(scale_unconstrained)

        if inverse:
            z_1 = (z_1 - b) / s
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = s * z_1 + b
            z = torch.cat((z_0, z_1), dim=1)
            return z, log_s

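A minimal round-trip sketch for the affine coupling layer, assuming common.py is importable: z_0 passes through unchanged, z_1 is scaled and shifted conditioned on z_0 and the context, and inverse=True undoes the transform.

import torch
from common import AffineTransformationLayer

layer = AffineTransformationLayer(
    n_mel_channels=8, n_context_dim=3, n_layers=1, scaling_fn="tanh"
)
z = torch.randn(2, 8, 5)
context = torch.randn(2, 3, 5)
z_fwd, log_s = layer(z, context)             # forward (training) direction
z_rec = layer(z_fwd, context, inverse=True)  # inference direction
print(torch.allclose(z, z_rec, atol=1e-5))   # True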
class ConvAttention(torch.nn.Module):
    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0
    ):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            ConvNorm(
                n_text_channels,
                n_text_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        self.query_proj = nn.Sequential(
            ConvNorm(
                n_mel_channels,
                n_mel_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def run_padded_sequence(
        self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
    ):
        """Sorts input data by provided ordering (and un-ordering) and runs the
        packed data through the recurrent model

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
            lens: lengths of input data (sorted in descending order)
            padded_data (torch.tensor): input sequences (padded)
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
            unsorted, ordering
        """

        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def forward(
        self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None
    ):
        """Attention mechanism for radtts. Unlike in Flowtron, we have no
        restrictions such as causality etc, since we only need this during
        training.

        Args:
            queries (torch.tensor): B x C x T1 tensor (likely mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
            mask (torch.tensor): uint8 binary mask for variable length entries
                (should be in the T2 domain)
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
        """
        temp = 0.0005
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        # Beware can only do this since query_dim = attn_dim = n_mel_channels
        queries_enc = self.query_proj(queries)

        # Gaussian Isotropic Attention
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2

        # compute log-likelihood from gaussian
        eps = 1e-8
        attn = -temp * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)

        attn_logprob = attn.clone()

        if mask is not None:
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob
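A minimal shape sketch for ConvAttention, assuming common.py is importable: the soft alignment is B x 1 x T1 x T2 and normalized over the text axis.

import torch
from common import ConvAttention

attn_fn = ConvAttention(n_mel_channels=80, n_text_channels=512, n_att_channels=80)
mel = torch.randn(2, 80, 12)   # queries: B x n_mel_channels x T1
text = torch.randn(2, 512, 7)  # keys:    B x n_text_channels x T2
attn, attn_logprob = attn_fn(mel, text, query_lens=None)
print(attn.shape)              # torch.Size([2, 1, 12, 7])
print(attn.sum(-1)[0, 0, 0])   # ~1.0 after the softmax over T2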
configs/radtts-pp-dap-model.json
ADDED
@@ -0,0 +1,218 @@
1 |
+
{
|
2 |
+
"train_config": {
|
3 |
+
"output_directory": "outdir_pp_model",
|
4 |
+
"epochs": 10000000,
|
5 |
+
"optim_algo": "RAdam",
|
6 |
+
"learning_rate": 0.001,
|
7 |
+
"weight_decay": 1e-06,
|
8 |
+
"sigma": 1.0,
|
9 |
+
"iters_per_checkpoint": 1000,
|
10 |
+
"batch_size": 16,
|
11 |
+
"seed": null,
|
12 |
+
"checkpoint_path": "",
|
13 |
+
"ignore_layers": [],
|
14 |
+
"ignore_layers_warmstart": [],
|
15 |
+
"finetune_layers": [],
|
16 |
+
"include_layers": [],
|
17 |
+
"vocoder_config_path": "models/hifigan_22khz_config.json",
|
18 |
+
"vocoder_checkpoint_path": "models/hifigan_ljs_generator_v1.pt",
|
19 |
+
"log_attribute_samples": true,
|
20 |
+
"log_decoder_samples": true,
|
21 |
+
"warmstart_checkpoint_path": "outdir_pp/model_100000",
|
22 |
+
"use_amp": true,
|
23 |
+
"grad_clip_val": 1.0,
|
24 |
+
"loss_weights": {
|
25 |
+
"blank_logprob": -1,
|
26 |
+
"ctc_loss_weight": 0.1,
|
27 |
+
"binarization_loss_weight": 1.0,
|
28 |
+
"dur_loss_weight": 1.0,
|
29 |
+
"f0_loss_weight": 1.0,
|
30 |
+
"energy_loss_weight": 1.0,
|
31 |
+
"vpred_loss_weight": 1.0
|
32 |
+
},
|
33 |
+
"binarization_start_iter": 0,
|
34 |
+
"kl_loss_start_iter": 0,
|
35 |
+
"unfreeze_modules": "all"
|
36 |
+
},
|
37 |
+
"data_config": {
|
38 |
+
"training_files": {
|
39 |
+
"LJS": {
|
40 |
+
"basedir": "filelists/",
|
41 |
+
"audiodir": "wavs",
|
42 |
+
"filelist": "3speakers_ukrainian_train_filelist_dc.txt",
|
43 |
+
"lmdbpath": ""
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"validation_files": {
|
47 |
+
"LJS": {
|
48 |
+
"basedir": "filelists/",
|
49 |
+
"audiodir": "wavs",
|
50 |
+
"filelist": "3speakers_ukrainian_val_filelist_dc.txt",
|
51 |
+
"lmdbpath": ""
|
52 |
+
}
|
53 |
+
},
|
54 |
+
"dur_min": 0.1,
|
55 |
+
"dur_max": 10.2,
|
56 |
+
"sampling_rate": 22050,
|
57 |
+
"filter_length": 1024,
|
58 |
+
"hop_length": 256,
|
59 |
+
"win_length": 1024,
|
60 |
+
"n_mel_channels": 80,
|
61 |
+
"mel_fmin": 0.0,
|
62 |
+
"mel_fmax": 8000.0,
|
63 |
+
"f0_min": 80.0,
|
64 |
+
"f0_max": 640.0,
|
65 |
+
"max_wav_value": 32768.0,
|
66 |
+
"use_f0": true,
|
67 |
+
"use_log_f0": 0,
|
68 |
+
"use_energy_avg": true,
|
69 |
+
"use_scaled_energy": true,
|
70 |
+
"symbol_set": "ukrainian",
|
71 |
+
"cleaner_names": [
|
72 |
+
"ukrainian_cleaners"
|
73 |
+
],
|
74 |
+
"heteronyms_path": "tts_text_processing/heteronyms",
|
75 |
+
"phoneme_dict_path": "tts_text_processing/cmudict-0.7b",
|
76 |
+
"p_phoneme": 0.0,
|
77 |
+
"handle_phoneme": "word",
|
78 |
+
"handle_phoneme_ambiguous": "ignore",
|
79 |
+
"include_speakers": null,
|
80 |
+
"n_frames": -1,
|
81 |
+
"betabinom_cache_path": "/home/dmytro_chaplinsky/RAD-TTS/radtts-code/cache",
|
82 |
+
"lmdb_cache_path": "",
|
83 |
+
"use_attn_prior_masking": true,
|
84 |
+
"prepend_space_to_text": true,
|
85 |
+
"append_space_to_text": true,
|
86 |
+
"add_bos_eos_to_text": false,
|
87 |
+
"betabinom_scaling_factor": 1.0,
|
88 |
+
"distance_tx_unvoiced": false,
|
89 |
+
"mel_noise_scale": 0.0
|
90 |
+
},
|
91 |
+
"dist_config": {
|
92 |
+
"dist_backend": "nccl",
|
93 |
+
"dist_url": "tcp://localhost:54321"
|
94 |
+
},
|
95 |
+
"model_config": {
|
96 |
+
"n_speakers": 3,
|
97 |
+
"n_speaker_dim": 16,
|
98 |
+
"n_text": 185,
|
99 |
+
"n_text_dim": 512,
|
100 |
+
"n_flows": 8,
|
101 |
+
"n_conv_layers_per_step": 4,
|
102 |
+
"n_mel_channels": 80,
|
103 |
+
"n_hidden": 1024,
|
104 |
+
"mel_encoder_n_hidden": 512,
|
105 |
+
"dummy_speaker_embedding": false,
|
106 |
+
"n_early_size": 2,
|
107 |
+
"n_early_every": 2,
|
108 |
+
"n_group_size": 2,
|
109 |
+
"affine_model": "wavenet",
|
110 |
+
"include_modules": "decatndpmvpredapm",
|
111 |
+
"scaling_fn": "tanh",
|
112 |
+
"matrix_decomposition": "LUS",
|
113 |
+
"learn_alignments": true,
|
114 |
+
"use_speaker_emb_for_alignment": false,
|
115 |
+
"attn_straight_through_estimator": true,
|
116 |
+
"use_context_lstm": true,
|
117 |
+
"context_lstm_norm": "spectral",
|
118 |
+
"context_lstm_w_f0_and_energy": true,
|
119 |
+
"text_encoder_lstm_norm": "spectral",
|
120 |
+
"n_f0_dims": 1,
|
121 |
+
"n_energy_avg_dims": 1,
|
122 |
+
"use_first_order_features": false,
|
123 |
+
"unvoiced_bias_activation": "relu",
|
124 |
+
"decoder_use_partial_padding": true,
|
125 |
+
"decoder_use_unvoiced_bias": true,
|
126 |
+
"ap_pred_log_f0": true,
|
127 |
+
"ap_use_unvoiced_bias": false,
|
128 |
+
"ap_use_voiced_embeddings": true,
|
129 |
+
"dur_model_config": {
|
130 |
+
"name": "dap",
|
131 |
+
"hparams": {
|
132 |
+
"n_speaker_dim": 16,
|
133 |
+
"bottleneck_hparams": {
|
134 |
+
"in_dim": 512,
|
135 |
+
"reduction_factor": 16,
|
136 |
+
"norm": "weightnorm",
|
137 |
+
"non_linearity": "relu"
|
138 |
+
},
|
139 |
+
"take_log_of_input": true,
|
140 |
+
"arch_hparams": {
|
141 |
+
"out_dim": 1,
|
142 |
+
"n_layers": 2,
|
143 |
+
"n_channels": 256,
|
144 |
+
"kernel_size": 3,
|
145 |
+
"p_dropout": 0.25,
|
146 |
+
"in_dim": 48
|
147 |
+
}
|
148 |
+
}
|
149 |
+
},
|
150 |
+
"f0_model_config": {
|
151 |
+
"name": "dap",
|
152 |
+
"hparams": {
|
153 |
+
"n_speaker_dim": 16,
|
154 |
+
"bottleneck_hparams": {
|
155 |
+
"in_dim": 512,
|
156 |
+
"reduction_factor": 16,
|
157 |
+
"norm": "weightnorm",
|
158 |
+
"non_linearity": "relu"
|
159 |
+
},
|
160 |
+
"take_log_of_input": false,
|
161 |
+
"use_transformer": false,
|
162 |
+
"arch_hparams": {
|
163 |
+
"out_dim": 1,
|
164 |
+
"n_layers": 2,
|
165 |
+
"n_channels": 256,
|
166 |
+
"kernel_size": 11,
|
167 |
+
"p_dropout": 0.5,
|
168 |
+
"in_dim": 48
|
169 |
+
}
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"energy_model_config": {
|
173 |
+
"name": "dap",
|
174 |
+
"hparams": {
|
175 |
+
"n_speaker_dim": 16,
|
176 |
+
"bottleneck_hparams": {
|
177 |
+
"in_dim": 512,
|
178 |
+
"reduction_factor": 16,
|
179 |
+
"norm": "weightnorm",
|
180 |
+
"non_linearity": "relu"
|
181 |
+
},
|
182 |
+
"take_log_of_input": false,
|
183 |
+
"use_transformer": false,
|
184 |
+
"arch_hparams": {
|
185 |
+
"out_dim": 1,
|
186 |
+
"n_layers": 2,
|
187 |
+
"n_channels": 256,
|
188 |
+
"kernel_size": 3,
|
189 |
+
"p_dropout": 0.25,
|
190 |
+
"in_dim": 48
|
191 |
+
}
|
192 |
+
}
|
193 |
+
},
|
194 |
+
"v_model_config": {
|
195 |
+
"name": "dap",
|
196 |
+
"hparams": {
|
197 |
+
"n_speaker_dim": 16,
|
198 |
+
"take_log_of_input": false,
|
199 |
+
"bottleneck_hparams": {
|
200 |
+
"in_dim": 512,
|
201 |
+
"reduction_factor": 16,
|
202 |
+
"norm": "weightnorm",
|
203 |
+
"non_linearity": "relu"
|
204 |
+
},
|
205 |
+
"arch_hparams": {
|
206 |
+
"out_dim": 1,
|
207 |
+
"n_layers": 2,
|
208 |
+
"n_channels": 256,
|
209 |
+
"kernel_size": 3,
|
210 |
+
"p_dropout": 0.5,
|
211 |
+
"lstm_type": "",
|
212 |
+
"use_linear": 1,
|
213 |
+
"in_dim": 48
|
214 |
+
}
|
215 |
+
}
|
216 |
+
}
|
217 |
+
}
|
218 |
+
}
|
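The data_config block above fixes the audio front-end geometry. As a quick orientation (a minimal sketch, not part of the commit; it only assumes the config path committed in this Space), here is what those numbers imply for mel-frame counts:

import json

# Load the training config committed above.
with open("configs/radtts-pp-dap-model.json") as f:
    config = json.load(f)

data_cfg = config["data_config"]

# Mel frames per second = sampling_rate / hop_length = 22050 / 256 ~ 86.1,
# so a clip at dur_max = 10.2 s is roughly 878 mel frames.
frames_per_second = data_cfg["sampling_rate"] / data_cfg["hop_length"]
max_frames = int(data_cfg["dur_max"] * frames_per_second)
print(f"{frames_per_second:.1f} frames/s, up to ~{max_frames} frames per clip")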
data.py
ADDED
@@ -0,0 +1,606 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Based on https://github.com/NVIDIA/flowtron/blob/master/data.py
# Original license text:
###############################################################################
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################

import os
import argparse
import json
import numpy as np
import lmdb
import pickle as pkl
import torch
import torch.utils.data
from scipy.io.wavfile import read
from audio_processing import TacotronSTFT
from tts_text_processing.text_processing import TextProcessing
from scipy.stats import betabinom
from librosa import pyin
from common import update_params
from scipy.ndimage import distance_transform_edt as distance_transform


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
    P = phoneme_count
    M = mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    for i in range(1, M + 1):
        a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
        rv = betabinom(P - 1, a, b)
        mel_i_prob = rv.pmf(x)
        mel_text_probs.append(mel_i_prob)
    return torch.tensor(np.array(mel_text_probs))


def load_wav_to_torch(full_path):
    """Loads wav data into a torch array"""
    sampling_rate, data = read(full_path)
    return torch.from_numpy(np.array(data)).float(), sampling_rate


class Data(torch.utils.data.Dataset):
    def __init__(
        self,
        datasets,
        filter_length,
        hop_length,
        win_length,
        sampling_rate,
        n_mel_channels,
        mel_fmin,
        mel_fmax,
        f0_min,
        f0_max,
        max_wav_value,
        use_f0,
        use_energy_avg,
        use_log_f0,
        use_scaled_energy,
        symbol_set,
        cleaner_names,
        heteronyms_path,
        phoneme_dict_path,
        p_phoneme,
        handle_phoneme="word",
        handle_phoneme_ambiguous="ignore",
        speaker_ids=None,
        include_speakers=None,
        n_frames=-1,
        use_attn_prior_masking=True,
        prepend_space_to_text=True,
        append_space_to_text=True,
        add_bos_eos_to_text=False,
        betabinom_cache_path="",
        betabinom_scaling_factor=0.05,
        lmdb_cache_path="",
        dur_min=None,
        dur_max=None,
        combine_speaker_and_emotion=False,
        **kwargs,
    ):
        self.combine_speaker_and_emotion = combine_speaker_and_emotion
        self.max_wav_value = max_wav_value
        self.audio_lmdb_dict = {}  # dictionary of lmdbs for audio data
        self.data = self.load_data(datasets)
        self.distance_tx_unvoiced = False
        if "distance_tx_unvoiced" in kwargs.keys():
            self.distance_tx_unvoiced = kwargs["distance_tx_unvoiced"]
        self.stft = TacotronSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            sampling_rate=sampling_rate,
            n_mel_channels=n_mel_channels,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
        )

        self.do_mel_scaling = kwargs.get("do_mel_scaling", True)
        self.mel_noise_scale = kwargs.get("mel_noise_scale", 0.0)
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.use_f0 = use_f0
        self.use_log_f0 = use_log_f0
        self.use_energy_avg = use_energy_avg
        self.use_scaled_energy = use_scaled_energy
        self.sampling_rate = sampling_rate
        self.tp = TextProcessing(
            symbol_set,
            cleaner_names,
            heteronyms_path,
            phoneme_dict_path,
            p_phoneme=p_phoneme,
            handle_phoneme=handle_phoneme,
            handle_phoneme_ambiguous=handle_phoneme_ambiguous,
            prepend_space_to_text=prepend_space_to_text,
            append_space_to_text=append_space_to_text,
            add_bos_eos_to_text=add_bos_eos_to_text,
        )

        self.dur_min = dur_min
        self.dur_max = dur_max
        if speaker_ids is None or speaker_ids == "":
            self.speaker_ids = self.create_speaker_lookup_table(self.data)
        else:
            self.speaker_ids = speaker_ids

        print("Number of files", len(self.data))
        if include_speakers is not None:
            for speaker_set, include in include_speakers:
                self.filter_by_speakers_(speaker_set, include)
            print("Number of files after speaker filtering", len(self.data))

        if dur_min is not None and dur_max is not None:
            self.filter_by_duration_(dur_min, dur_max)
            print("Number of files after duration filtering", len(self.data))

        self.use_attn_prior_masking = bool(use_attn_prior_masking)
        self.prepend_space_to_text = bool(prepend_space_to_text)
        self.append_space_to_text = bool(append_space_to_text)
        self.betabinom_cache_path = betabinom_cache_path
        self.betabinom_scaling_factor = betabinom_scaling_factor
        self.lmdb_cache_path = lmdb_cache_path
        if self.lmdb_cache_path != "":
            self.cache_data_lmdb = lmdb.open(
                self.lmdb_cache_path, readonly=True, max_readers=1024, lock=False
            ).begin()

        # # make sure caching path exists
        # if not os.path.exists(self.betabinom_cache_path):
        #     os.makedirs(self.betabinom_cache_path)

        print("Dataloader initialized with no augmentations")
        self.speaker_map = None
        if "speaker_map" in kwargs:
            self.speaker_map = kwargs["speaker_map"]

    def load_data(self, datasets, split="|"):
        dataset = []
        for dset_name, dset_dict in datasets.items():
            folder_path = dset_dict["basedir"]
            audiodir = dset_dict["audiodir"]
            filename = dset_dict["filelist"]
            audio_lmdb_key = None
            if "lmdbpath" in dset_dict.keys() and len(dset_dict["lmdbpath"]) > 0:
                self.audio_lmdb_dict[dset_name] = lmdb.open(
                    dset_dict["lmdbpath"], readonly=True, max_readers=256, lock=False
                ).begin()
                audio_lmdb_key = dset_name

            wav_folder_prefix = os.path.join(folder_path, audiodir)
            filelist_path = os.path.join(folder_path, filename)
            with open(filelist_path, encoding="utf-8") as f:
                data = [line.strip().split(split) for line in f]

            for d in data:
                emotion = "other" if len(d) == 3 else d[3]
                duration = -1 if len(d) == 3 else d[4]
                dataset.append(
                    {
                        "audiopath": os.path.join(wav_folder_prefix, d[0]),
                        "text": d[1],
                        "speaker": d[2] + "-" + emotion
                        if self.combine_speaker_and_emotion
                        else d[2],
                        "emotion": emotion,
                        "duration": float(duration),
                        "lmdb_key": audio_lmdb_key,
                    }
                )
        return dataset

    def filter_by_speakers_(self, speakers, include=True):
        print("Include speaker {}: {}".format(speakers, include))
        if include:
            self.data = [x for x in self.data if x["speaker"] in speakers]
        else:
            self.data = [x for x in self.data if x["speaker"] not in speakers]

    def filter_by_duration_(self, dur_min, dur_max):
        self.data = [
            x
            for x in self.data
            if x["duration"] == -1
            or (x["duration"] >= dur_min and x["duration"] <= dur_max)
        ]

    def create_speaker_lookup_table(self, data):
        speaker_ids = np.sort(np.unique([x["speaker"] for x in data]))
        d = {speaker_ids[i]: i for i in range(len(speaker_ids))}
        print("Number of speakers:", len(d))
        print("Speaker IDS", d)
        return d

    def f0_normalize(self, x):
        if self.use_log_f0:
            mask = x >= self.f0_min
            x[mask] = torch.log(x[mask])
            x[~mask] = 0.0

        return x

    def f0_denormalize(self, x):
        if self.use_log_f0:
            log_f0_min = np.log(self.f0_min)
            mask = x >= log_f0_min
            x[mask] = torch.exp(x[mask])
            x[~mask] = 0.0
        x[x <= 0.0] = 0.0

        return x

    def energy_avg_normalize(self, x):
        if self.use_scaled_energy:
            x = (x + 20.0) / 20.0
        return x

    def energy_avg_denormalize(self, x):
        if self.use_scaled_energy:
            x = x * 20.0 - 20.0
        return x

    def get_f0_pvoiced(
        self,
        audio,
        sampling_rate=22050,
        frame_length=1024,
        hop_length=256,
        f0_min=100,
        f0_max=300,
    ):
        audio_norm = audio / self.max_wav_value
        f0, voiced_mask, p_voiced = pyin(
            audio_norm,
            f0_min,
            f0_max,
            sampling_rate,
            frame_length=frame_length,
            win_length=frame_length // 2,
            hop_length=hop_length,
        )
        f0[~voiced_mask] = 0.0
        f0 = torch.FloatTensor(f0)
        p_voiced = torch.FloatTensor(p_voiced)
        voiced_mask = torch.FloatTensor(voiced_mask)
        return f0, voiced_mask, p_voiced

    def get_energy_average(self, mel):
        energy_avg = mel.mean(0)
        energy_avg = self.energy_avg_normalize(energy_avg)
        return energy_avg

    def get_mel(self, audio):
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        if self.do_mel_scaling:
            melspec = (melspec + 5.5) / 2
        if self.mel_noise_scale > 0:
            melspec += torch.randn_like(melspec) * self.mel_noise_scale
        return melspec

    def get_speaker_id(self, speaker):
        if self.speaker_map is not None and speaker in self.speaker_map:
            speaker = self.speaker_map[speaker]

        return torch.LongTensor([self.speaker_ids[speaker]])

    def get_text(self, text):
        text = self.tp.encode_text(text)
        text = torch.LongTensor(text)
        return text

    def get_attention_prior(self, n_tokens, n_frames):
        # cache the entire attn_prior by filename
        if self.use_attn_prior_masking:
            filename = "{}_{}".format(n_tokens, n_frames)
            prior_path = os.path.join(self.betabinom_cache_path, filename)
            prior_path += "_prior.pth"
            if self.lmdb_cache_path != "":
                attn_prior = pkl.loads(
                    self.cache_data_lmdb.get(prior_path.encode("ascii"))
                )
            elif os.path.exists(prior_path):
                attn_prior = torch.load(prior_path)
            else:
                attn_prior = beta_binomial_prior_distribution(
                    n_tokens, n_frames, self.betabinom_scaling_factor
                )
                torch.save(attn_prior, prior_path)
        else:
            attn_prior = torch.ones(n_frames, n_tokens)  # all ones baseline

        return attn_prior

    def __getitem__(self, index):
        data = self.data[index]
        audiopath, text = data["audiopath"], data["text"]
        speaker_id = data["speaker"]

        if data["lmdb_key"] is not None:
            data_dict = pkl.loads(
                self.audio_lmdb_dict[data["lmdb_key"]].get(audiopath.encode("ascii"))
            )
            audio = data_dict["audio"]
            sampling_rate = data_dict["sampling_rate"]
        else:
            audio, sampling_rate = load_wav_to_torch(audiopath)

        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )

        mel = self.get_mel(audio)
        f0 = None
        p_voiced = None
        voiced_mask = None
        if self.use_f0:
            filename = "_".join(audiopath.split("/")[-3:])
            f0_path = os.path.join(self.betabinom_cache_path, filename)
            f0_path += "_f0_sr{}_fl{}_hl{}_f0min{}_f0max{}_log{}.pt".format(
                self.sampling_rate,
                self.filter_length,
                self.hop_length,
                self.f0_min,
                self.f0_max,
                self.use_log_f0,
            )

            dikt = None
            if len(self.lmdb_cache_path) > 0:
                dikt = pkl.loads(self.cache_data_lmdb.get(f0_path.encode("ascii")))
                f0 = dikt["f0"]
                p_voiced = dikt["p_voiced"]
                voiced_mask = dikt["voiced_mask"]
            elif os.path.exists(f0_path):
                try:
                    dikt = torch.load(f0_path)
                except Exception:
                    print(f"f0 loading from {f0_path} is broken, recomputing.")

            if dikt is not None:
                f0 = dikt["f0"]
                p_voiced = dikt["p_voiced"]
                voiced_mask = dikt["voiced_mask"]
            else:
                f0, voiced_mask, p_voiced = self.get_f0_pvoiced(
                    audio.cpu().numpy(),
                    self.sampling_rate,
                    self.filter_length,
                    self.hop_length,
                    self.f0_min,
                    self.f0_max,
                )
                print("saving f0 to {}".format(f0_path))
                torch.save(
                    {"f0": f0, "voiced_mask": voiced_mask, "p_voiced": p_voiced},
                    f0_path,
                )
            if f0 is None:
                raise Exception("STOP, BROKEN F0 {}".format(audiopath))

            f0 = self.f0_normalize(f0)
            if self.distance_tx_unvoiced:
                mask = f0 <= 0.0
                distance_map = np.log(distance_transform(mask))
                distance_map[distance_map <= 0] = 0.0
                f0 = f0 - distance_map

        energy_avg = None
        if self.use_energy_avg:
            energy_avg = self.get_energy_average(mel)
            if self.use_scaled_energy and energy_avg.min() < 0.0:
                print(audiopath, "has scaled energy avg smaller than 0")

        speaker_id = self.get_speaker_id(speaker_id)
        text_encoded = self.get_text(text)

        attn_prior = self.get_attention_prior(text_encoded.shape[0], mel.shape[1])

        if not self.use_attn_prior_masking:
            attn_prior = None

        return {
            "mel": mel,
            "speaker_id": speaker_id,
            "text_encoded": text_encoded,
            "audiopath": audiopath,
            "attn_prior": attn_prior,
            "f0": f0,
            "p_voiced": p_voiced,
            "voiced_mask": voiced_mask,
            "energy_avg": energy_avg,
        }

    def __len__(self):
        return len(self.data)


class DataCollate:
    """Zero-pads model inputs and targets given number of steps"""

    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate from normalized data"""
        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x["text_encoded"]) for x in batch]),
            dim=0,
            descending=True,
        )

        max_input_len = input_lengths[0]
        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]]["text_encoded"]
            text_padded[i, : text.size(0)] = text

        # Right zero-pad mel-spec
        num_mel_channels = batch[0]["mel"].size(0)
        max_target_len = max([x["mel"].size(1) for x in batch])

        # include mel padded, gate padded and speaker ids
        mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
        mel_padded.zero_()
        f0_padded = None
        p_voiced_padded = None
        voiced_mask_padded = None
        energy_avg_padded = None
        if batch[0]["f0"] is not None:
            f0_padded = torch.FloatTensor(len(batch), max_target_len)
            f0_padded.zero_()

        if batch[0]["p_voiced"] is not None:
            p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
            p_voiced_padded.zero_()

        if batch[0]["voiced_mask"] is not None:
            voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
            voiced_mask_padded.zero_()

        if batch[0]["energy_avg"] is not None:
            energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
            energy_avg_padded.zero_()

        attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
        attn_prior_padded.zero_()

        output_lengths = torch.LongTensor(len(batch))
        speaker_ids = torch.LongTensor(len(batch))
        audiopaths = []
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]]["mel"]
            mel_padded[i, :, : mel.size(1)] = mel
            if batch[ids_sorted_decreasing[i]]["f0"] is not None:
                f0 = batch[ids_sorted_decreasing[i]]["f0"]
                f0_padded[i, : len(f0)] = f0

            if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
                voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
                voiced_mask_padded[i, : len(f0)] = voiced_mask

            if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
                p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
                p_voiced_padded[i, : len(f0)] = p_voiced

            if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
                energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
                energy_avg_padded[i, : len(energy_avg)] = energy_avg

            output_lengths[i] = mel.size(1)
            speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
            audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
            audiopaths.append(audiopath)
            cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
            if cur_attn_prior is None:
                attn_prior_padded = None
            else:
                attn_prior_padded[
                    i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
                ] = cur_attn_prior

        return {
            "mel": mel_padded,
            "speaker_ids": speaker_ids,
            "text": text_padded,
            "input_lengths": input_lengths,
            "output_lengths": output_lengths,
            "audiopaths": audiopaths,
            "attn_prior": attn_prior_padded,
            "f0": f0_padded,
            "p_voiced": p_voiced_padded,
            "voiced_mask": voiced_mask_padded,
            "energy_avg": energy_avg_padded,
        }


# ===================================================================
# Takes directory of clean audio and makes directory of spectrograms
# Useful for making test sets
# ===================================================================
if __name__ == "__main__":
    # Get defaults so it can work with no Sacred
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
    parser.add_argument("-p", "--params", nargs="+", default=[])
    args = parser.parse_args()
    args.rank = 0

    # Parse configs. Globals nicer in this case
    with open(args.config) as f:
        data = f.read()

    config = json.loads(data)
    update_params(config, args.params)
    print(config)

    data_config = config["data_config"]

    ignore_keys = ["training_files", "validation_files"]
    trainset = Data(
        data_config["training_files"],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
    )

    valset = Data(
        data_config["validation_files"],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
        speaker_ids=trainset.speaker_ids,
    )

    collate_fn = DataCollate()

    for dataset in (trainset, valset):
        for i, batch in enumerate(dataset):
            out = batch
            print("{}/{}".format(i, len(dataset)))
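A minimal sketch (not part of the commit) of how Data and DataCollate would typically be wired into a PyTorch DataLoader for this config. It assumes the filelists and betabinom_cache_path named in data_config resolve on the local machine:

import json

from torch.utils.data import DataLoader

from data import Data, DataCollate

# Pull the data_config block from the committed config file.
with open("configs/radtts-pp-dap-model.json") as f:
    data_config = json.load(f)["data_config"]

ignore_keys = ["training_files", "validation_files"]
trainset = Data(
    data_config["training_files"],
    **{k: v for k, v in data_config.items() if k not in ignore_keys},
)

# DataCollate right-zero-pads text, mel, f0, energy and the attention prior
# to the longest item in the batch, sorted by text length (descending).
loader = DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    collate_fn=DataCollate(),
    num_workers=2,
)

batch = next(iter(loader))
print(batch["mel"].shape)      # (4, 80, max_target_len)
print(batch["input_lengths"])  # sorted descending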
distributed.py
ADDED
@@ -0,0 +1,161 @@
# Original source: https://github.com/NVIDIA/waveglow/blob/master/distributed.py
#
# Original license text:
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import os
import torch
import torch.distributed as dist
from torch.autograd import Variable


def reduce_tensor(tensor, num_gpus, reduce_dst=None):
    if num_gpus <= 1:  # pass-thru
        return tensor
    rt = tensor.clone()
    if reduce_dst is not None:
        dist.reduce(rt, reduce_dst, op=dist.ReduceOp.SUM)
    else:
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt


def init_distributed(rank, num_gpus, dist_backend, dist_url):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."

    print("> initializing distributed for rank {} out of {}".format(rank, num_gpus))

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    init_method = "tcp://"
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    init_method += master_ip + ":" + master_port
    torch.distributed.init_process_group(
        backend="nccl", world_size=num_gpus, rank=rank, init_method=init_method
    )


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)


def apply_gradient_allreduce(module):
    """
    Modifies existing model to do gradient allreduce, but doesn't change class
    so you don't need "module"
    """
    if not hasattr(dist, "_backend"):
        module.warn_on_half = True
    else:
        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        if module.needs_reduction:
            module.needs_reduction = False
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = type(param.data)
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                if torch.cuda.HalfTensor in buckets:
                    print(
                        "WARNING: gloo dist backend for half parameters may be extremely slow."
                        + " It is recommended to use the NCCL backend in this case."
                        + " This currently requires PyTorch built from top of tree master."
                    )
                    module.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(
                    grads, _unflatten_dense_tensors(coalesced, grads)
                ):
                    buf.copy_(synced)

    for param in list(module.parameters()):

        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)

        if param.requires_grad:
            param.register_hook(allreduce_hook)
            dir(param)

    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module
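A hypothetical two-GPU usage sketch for these helpers (not part of the commit). The backend and URL values mirror dist_config in the committed config, and the sketch assumes two local CUDA devices; init_distributed itself reads MASTER_ADDR/MASTER_PORT from the environment:

import os
import torch
import torch.multiprocessing as mp

from distributed import init_distributed, reduce_tensor


def worker(rank, num_gpus):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "6000")
    init_distributed(rank, num_gpus, "nccl", "tcp://localhost:54321")

    # Each rank produces its own loss; reduce_tensor averages it across GPUs.
    local_loss = torch.tensor([float(rank)], device="cuda")
    avg_loss = reduce_tensor(local_loss, num_gpus)
    print(f"rank {rank}: averaged loss = {avg_loss.item()}")


if __name__ == "__main__":
    num_gpus = 2  # hypothetical; requires two CUDA devices
    mp.spawn(worker, args=(num_gpus,), nprocs=num_gpus)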
filelists/3speakers_ukrainian_train_filelist.txt
ADDED
The diff for this file is too large to render.
filelists/3speakers_ukrainian_train_filelist_dc.txt
ADDED
The diff for this file is too large to render.
filelists/3speakers_ukrainian_val_filelist.txt
ADDED
@@ -0,0 +1,85 @@
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
/home/yehor/RADTTS-Multiple-Voices/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
filelists/3speakers_ukrainian_val_filelist_dc.txt
ADDED
@@ -0,0 +1,85 @@
1 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48849.wav|мандрівник+и вп+ерто відмовл+ялися.|lada
|
2 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48850.wav|він уз+яв сок+иру й г+острим кінц+ем поч+ав розв+ажувати з+уби.|lada
|
3 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48851.wav|розгр+ібши сніг, тр+охи прос+унув г+олову й пл+ечі під шатр+о.|lada
|
4 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48853.wav|ал+е раз зас+идівся до п+ізнього в+ечора.|lada
|
5 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48854.wav|то ж не дим їй +очі роз'їд+ав, бо др+ова бул+и сух+і.|lada
|
6 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48855.wav|вон+а не м+ала теп+ер с+умніву, що в портоса з д+амою бул+а інтр+ига.|lada
|
7 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48857.wav|х+очуть укра+їну з під л+яхів визвол+яти.|lada
|
8 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48858.wav|там жінк+ам не д+уже догодж+ають.|lada
|
9 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48859.wav|і б+удьте спок+ійні! якщ+о вин+о нам не спод+обається, ми пошлем+о по +інше.|lada
|
10 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48830.wav|мій д+івер і я м+арно чек+али на вас вч+ора й позавч+ора.|lada
|
11 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48831.wav|п+ане д'артаньяне, ви п+ерший.|lada
|
12 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48832.wav|ось мо+я в+ідповідь.|lada
|
13 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48833.wav|хоч той так+и й д+ійсно д+урень.|lada
|
14 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48834.wav|ви давн+о не гр+али?|lada
|
15 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48835.wav|теп+ер їм довел+ось зазн+ати д+оброї бід+и в цій кра+їні.|lada
|
16 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48836.wav|позавч+ора був пісн+ий день, а там подав+али лиш+е скор+омне.|lada
|
17 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48837.wav|і не потреб+уєте всі роб+ити.|lada
|
18 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48838.wav|у рук+ах у н+еї бул+а нов+а зап+иска міл+еді.|lada
|
19 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48839.wav|і ч+етверо др+узів одн+им г+олосом повтор+или прис+ягу, запропон+овану від д'артаньяна.|lada
|
20 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48841.wav|іг+уменя ст+ала сл+ухати ув+ажніш, тр+охи пожвав+іла й всміхн+улася.|lada
|
21 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48842.wav|так ти цьог+о не роб+и й не втрач+айся, бо одн+аково не пом+оже.|lada
|
22 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48843.wav|туд+и і рв+еться н+аша душ+а, кол+и х+очеш зн+ати.|lada
|
23 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48844.wav|б+олісно всміх+ався і трясс+я, як у проп+асниці.|lada
|
24 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48845.wav|я прив+ів тоб+і др+угого, сказ+ав д'артаньян.|lada
|
25 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48846.wav|я поб+ачу корол+я сьог+одні увечорі, ал+е вас не р+аджу наверт+атись йому на в+ічі.|lada
|
26 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48847.wav|ще весел+іш почал+и тод+і гомон+іти.|lada
|
27 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/lada/accept/48848.wav|споч+атку вон+а нарахув+ала двох, п+отім п'ять, нар+ешті в+ісім.|lada
|
28 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68540.wav|кр+аще вже пуст+ити соб+і к+улю в л+оба і відр+азу покл+асти всь+ому край.|mykyta
|
29 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68541.wav|ал+е сидяч+и за стол+ом, при п+иві, знов поч+ув як+есь невдов+олення.|mykyta
|
30 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68543.wav|на шабл+ях!|mykyta
|
31 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68544.wav|вон+а пров+адила з незнай+омим д+уже жв+аву розм+ову.|mykyta
|
32 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68545.wav|офіц+ер взяв зі ст+олу вк+азані пап+ери, под+ав їх і, н+изько вклонившися, в+ийшов.|mykyta
|
33 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68546.wav|аж с+умно йому ст+ало.|mykyta
|
34 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68547.wav|житт+я не ласк+аве з багать+ох прич+ин.|mykyta
|
35 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68548.wav|так, звич+айно тр+еба, ств+ердила корол+ева.|mykyta
|
36 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68549.wav|вон+а, не зверн+увши ув+аги на цей д+ок+ір, промовл+яла д+алі.|mykyta
|
37 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68550.wav|зда+ється, не дочув+аю.|mykyta
|
38 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68551.wav|відв+ажний і завз+ятий, він не вп+ерше в+ажив сво+ї+++м житт+ям у так+их приг+одах.|mykyta
|
39 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68552.wav|як ч+асом, г+аво.|mykyta
|
40 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68553.wav|мій друг араміс, що оц+е сто+їть п+еред вами, здоб+ув легк+ого вд+ара шпад+ою в р+уку.|mykyta
|
41 |
+
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68554.wav|я знав+ець свог+о д+іла.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68556.wav|пог+онич леж+ав на с+анк+ах, а соб+аки шв+идко б+ігли пр+ямо до хат+ини.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68557.wav|міл+еді к+инулась до нього.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68558.wav|хто тоб+і сказ+ав?|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68559.wav|то й не поваж+ай, не зляк+аєш.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68560.wav|поясн+іть, бо я не розум+ію, що ви х+очете сказ+ати.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68561.wav|шрам наздогн+ав свій п+оїзд к+оло вис+оких вор+іт п+ана гвинтовки.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68562.wav|що ж він так+е?|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68563.wav|що це так+е? спит+ав портос.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/mykyta/accept/68565.wav|див+іться, тут зн+ову втруч+алася ц+ерква, з+авжд+и та ц+ерква.|mykyta
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67117.wav|а чолов+ік цьог+о жахл+ивого створ+іння ще жив+ий? зацік+авився араміс.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67118.wav|ви, дик, не ч+ули ці+єї т+иші.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67119.wav|він баг+атий на р+ок+и, шан+обу й сл+аву вел+ику.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67120.wav|в +осени зар+ані, ск+оро п+ісля сп+аса под+ався макс+им до київа.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67121.wav|а до н+еї п+ишеш?|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67122.wav|я, б+ачилось, н+авіть не люб+ив її так, як л+юблять зак+охані.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67123.wav|юрб+а провал+ила тим ч+асом м+имо петр+а.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67124.wav|хай так! приєдн+ався швайц+арець.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67125.wav|к+онюх підтв+ердив кардин+алові слов+а мушкет+ерів про атоса.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67126.wav|що завин+ив, те б+уду терп+іти.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67127.wav|чи є у вас тр+охи піск+у? ск+ільки? він показ+ав їй свій міш+ок.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67128.wav|я скаж+у це т+ільки том+у, хто прозирн+е в мо+ю д+ушу.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67129.wav|і в оц+ій хв+илі вон+а не міркув+ала тог+о.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67130.wav|ти б+ачив сво+ю ж?|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67132.wav|прот+е, тр+еба скл+асти як+ийсь плян б+ою, пром+овив араміс.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67133.wav|огого! д+уже швидк+а! так я теб+е й пуст+ив до богун+а!|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67134.wav|бог з тоб+ою, добр+одію!|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67135.wav|киценька! ти т+ямиш її?|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67136.wav|розм+ова поверн+ула на вес+еле.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67137.wav|розум+іється, сказ+ала вон+а к+оротко.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67138.wav|їй с+оромно ст+ало, що на оч+ах у всіх її так знев+ажено, і вон+а знен+авиділа фреду.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67139.wav|це бул+о м+ужнє обл+иччя.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67140.wav|св+екра зн+ала м+ало, не ч+асто й б+ачилася з ним, на рік раз+ів зо три.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67141.wav|спр+ава ця єсть особл+ивої делікатности.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67143.wav|я так отощ+ав, не +ївши зр+анку, що й р+адуватись незд+ужаю.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67145.wav|т+ільки в+ірна будь мен+і.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67146.wav|п'єр піш+ов за н+ею і відч+алив.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67147.wav|і по цих слов+ах к+инув торб+инку із з+олотом в р+ічку.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67148.wav|а, він в пор+ядку, сказ+ав нач+альник, та з чуд+овою рекоменд+ацією.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67149.wav|тод+і підожд+іть тр+ошки, зачек+айте.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67150.wav|із як+ими вістьми? пит+ає г+етьман.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67151.wav|стар+ий сарабр+ин міг л+егко пот+ішитися.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67152.wav|о, я, нещ+асний!|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67153.wav|кр+оки в сальоні.|tetiana
/home/dmytro_chaplinsky/RAD-TTS/datasets/tetiana/accept/67154.wav|щоб н+ашим ворог+ам бул+о т+яжко!|tetiana
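Note: each row above follows the `wav_path|transcript|speaker` filelist convention, with `+` apparently placed before the stressed vowel of the transcript. A minimal parsing sketch (the helper name is illustrative, not from the repo):

def parse_filelist_line(line):
    # wav_path|transcript|speaker; '+' precedes the stressed vowel
    wav_path, transcript, speaker = line.rstrip("\n").split("|")
    return wav_path, transcript, speaker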
loss.py
ADDED
@@ -0,0 +1,228 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
import torch.nn as nn
from torch.nn import functional as F
from common import get_mask_from_lengths


def compute_flow_loss(
    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
):
    log_det_W_total = 0.0
    for i, log_s in enumerate(log_s_list):
        if i == 0:
            log_s_total = torch.sum(log_s * mask)
            if len(log_det_W_list):
                log_det_W_total = log_det_W_list[i]
        else:
            log_s_total = log_s_total + torch.sum(log_s * mask)
            if len(log_det_W_list):
                log_det_W_total += log_det_W_list[i]

    if len(log_det_W_list):
        log_det_W_total *= n_elements

    z = z * mask
    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)

    loss = prior_NLL - log_s_total - log_det_W_total

    denom = n_elements * n_dims
    loss = loss / denom
    loss_prior = prior_NLL / denom
    return loss, loss_prior


def compute_regression_loss(x_hat, x, mask, name=False):
    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
    assert len(x.shape) == len(mask.shape)

    x = x * mask
    x_hat = x_hat * mask

    if name == "vpred":
        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
    else:
        loss = F.mse_loss(x_hat, x, reduction="sum")
    loss = loss / mask.sum()

    loss_dict = {"loss_{}".format(name): loss}

    return loss_dict


class AttributePredictionLoss(torch.nn.Module):
    def __init__(self, name, model_config, loss_weight, sigma=1.0):
        super(AttributePredictionLoss, self).__init__()
        self.name = name
        self.sigma = sigma
        self.model_name = model_config["name"]
        self.loss_weight = loss_weight
        self.n_group_size = 1
        if "n_group_size" in model_config["hparams"]:
            self.n_group_size = model_config["hparams"]["n_group_size"]

    def forward(self, model_output, lens):
        mask = get_mask_from_lengths(lens // self.n_group_size)
        mask = mask[:, None].float()
        loss_dict = {}
        if "z" in model_output:
            n_elements = lens.sum() // self.n_group_size
            n_dims = model_output["z"].size(1)

            loss, loss_prior = compute_flow_loss(
                model_output["z"],
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                n_elements,
                n_dims,
                mask,
                self.sigma,
            )
            loss_dict = {
                "loss_{}".format(self.name): (loss, self.loss_weight),
                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
            }
        elif "x_hat" in model_output:
            loss_dict = compute_regression_loss(
                model_output["x_hat"], model_output["x"], mask, self.name
            )
            for k, v in loss_dict.items():
                loss_dict[k] = (v, self.loss_weight)

        if len(loss_dict) == 0:
            raise Exception("loss not supported")

        return loss_dict


class AttentionCTCLoss(torch.nn.Module):
    def __init__(self, blank_logprob=-1):
        super(AttentionCTCLoss, self).__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.blank_logprob = blank_logprob
        self.CTCLoss = nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        key_lens = in_lens
        query_lens = out_lens
        attn_logprob_padded = F.pad(
            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
        )
        cost_total = 0.0
        for bid in range(attn_logprob.shape[0]):
            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
                : query_lens[bid], :, : key_lens[bid] + 1
            ]
            curr_logprob = self.log_softmax(curr_logprob[None])[0]
            ctc_cost = self.CTCLoss(
                curr_logprob,
                target_seq,
                input_lengths=query_lens[bid : bid + 1],
                target_lengths=key_lens[bid : bid + 1],
            )
            cost_total += ctc_cost
        cost = cost_total / attn_logprob.shape[0]
        return cost


class AttentionBinarizationLoss(torch.nn.Module):
    def __init__(self):
        super(AttentionBinarizationLoss, self).__init__()

    def forward(self, hard_attention, soft_attention):
        log_sum = torch.log(soft_attention[hard_attention == 1]).sum()
        return -log_sum / hard_attention.sum()


class RADTTSLoss(torch.nn.Module):
    def __init__(
        self,
        sigma=1.0,
        n_group_size=1,
        dur_model_config=None,
        f0_model_config=None,
        energy_model_config=None,
        vpred_model_config=None,
        loss_weights=None,
    ):
        super(RADTTSLoss, self).__init__()
        self.sigma = sigma
        self.n_group_size = n_group_size
        self.loss_weights = loss_weights
        self.attn_ctc_loss = AttentionCTCLoss(
            blank_logprob=loss_weights.get("blank_logprob", -1)
        )
        self.loss_fns = {}
        if dur_model_config is not None:
            self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
                "duration", dur_model_config, loss_weights["dur_loss_weight"]
            )

        if f0_model_config is not None:
            self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
                "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
            )

        if energy_model_config is not None:
            self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
                "energy", energy_model_config, loss_weights["energy_loss_weight"]
            )

        if vpred_model_config is not None:
            self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
                "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
            )

    def forward(self, model_output, in_lens, out_lens):
        loss_dict = {}
        if len(model_output["z_mel"]):
            n_elements = out_lens.sum() // self.n_group_size
            mask = get_mask_from_lengths(out_lens // self.n_group_size)
            mask = mask[:, None].float()
            n_dims = model_output["z_mel"].size(1)
            loss_mel, loss_prior_mel = compute_flow_loss(
                model_output["z_mel"],
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                n_elements,
                n_dims,
                mask,
                self.sigma,
            )
            loss_dict["loss_mel"] = (loss_mel, 1.0)  # loss, weight
            loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)

        ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
        loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])

        for k in model_output:
            if k in self.loss_fns:
                if model_output[k] is not None and len(model_output[k]) > 0:
                    t_lens = in_lens if "dur" in k else out_lens
                    mout = model_output[k]
                    for loss_name, v in self.loss_fns[k](mout, t_lens).items():
                        loss_dict[loss_name] = v

        return loss_dict
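Note: a minimal usage sketch of RADTTSLoss. Every value in the returned dict is a (loss, weight) tuple, so the total training loss is their weighted sum. The model_config/train_config names below are assumptions about the surrounding training code, not taken from this repo:

criterion = RADTTSLoss(
    sigma=1.0,
    n_group_size=model_config["n_group_size"],          # assumed config layout
    dur_model_config=model_config["dur_model_config"],
    f0_model_config=model_config["f0_model_config"],
    energy_model_config=model_config["energy_model_config"],
    vpred_model_config=model_config["v_model_config"],
    loss_weights=train_config["loss_weights"],
)
loss_dict = criterion(model_output, in_lens, out_lens)   # model_output from RADTTS.forward
total_loss = sum(loss * weight for loss, weight in loss_dict.values())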
partialconv1d.py
ADDED
@@ -0,0 +1,77 @@
# Modified partialconv source code based on implementation from
# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu ([email protected])
###############################################################################

# Original Author & Contact: Guilin Liu ([email protected])
# Modified by Kevin Shih ([email protected])

import torch
import torch.nn.functional as F
from torch import nn


class PartialConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        self.multi_channel = False
        self.return_mask = False
        super(PartialConv1d, self).__init__(*args, **kwargs)

        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
        self.slide_winsize = (
            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
        )

        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    @torch.jit.ignore
    def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
        """
        input: standard input to a 1D conv
        mask_in: binary mask for valid values, same shape as input
        """
        assert len(input.shape) == 3
        # if a mask is input, or tensor shape changed, update mask ratio
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    mask = torch.ones(1, 1, input.data.shape[2]).to(input)
                else:
                    mask = mask_in
                self.update_mask = F.conv1d(
                    mask,
                    self.weight_maskUpdater,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=1,
                )
                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
        raw_out = super(PartialConv1d, self).forward(
            torch.mul(input, mask) if mask_in is not None else input
        )
        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output
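Note: a quick sketch of PartialConv1d on a zero-padded batch. Since multi_channel is hard-coded to False, the mask carries a single channel, and the output is re-weighted by how much of each sliding window covers valid samples (the shapes here are illustrative):

import torch

conv = PartialConv1d(80, 256, kernel_size=5, padding=2)
x = torch.randn(4, 80, 120)                     # B x C x T, zero-padded batch
mask = torch.zeros(4, 1, 120)                   # 1 = valid frame, 0 = padding
for b, length in enumerate([120, 100, 90, 60]):
    mask[b, :, :length] = 1.0
y = conv(x, mask_in=mask)                       # 4 x 256 x 120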
radam.py
ADDED
@@ -0,0 +1,114 @@
# Original source taken from https://github.com/LiyuanLucasLiu/RAdam
#
# Copyright 2019 Liyuan Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch

# pylint: disable=no-name-in-module
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    """RAdam optimizer"""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """
        Init

        :param params: parameters to optimize
        :param lr: learning rate
        :param betas: beta
        :param eps: numerical precision
        :param weight_decay: weight decay weight
        """
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for _ in range(10)]
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                # keyword overloads: the positional (Number, Tensor) forms used
                # upstream are deprecated and removed in recent PyTorch
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state["step"] += 1
                buffered = self.buffer[int(state["step"] % 10)]
                if state["step"] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state["step"]
                    beta2_t = beta2 ** state["step"]
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = (
                            group["lr"]
                            * math.sqrt(
                                (1 - beta2_t)
                                * (N_sma - 4)
                                / (N_sma_max - 4)
                                * (N_sma - 2)
                                / N_sma
                                * N_sma_max
                                / (N_sma_max - 2)
                            )
                            / (1 - beta1 ** state["step"])
                        )
                    else:
                        step_size = group["lr"] / (1 - beta1 ** state["step"])
                    buffered[2] = step_size

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss
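Note: RAdam implements the standard torch.optim.Optimizer interface, so it drops into the usual training loop (model, batch, and criterion below are placeholders):

optimizer = RAdam(model.parameters(), lr=1e-4, weight_decay=1e-6)

optimizer.zero_grad()
loss = criterion(model(batch))  # placeholder forward pass and loss
loss.backward()
optimizer.step()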
radtts.py
ADDED
@@ -0,0 +1,936 @@
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: MIT
|
3 |
+
#
|
4 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
5 |
+
# copy of this software and associated documentation files (the "Software"),
|
6 |
+
# to deal in the Software without restriction, including without limitation
|
7 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
8 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
9 |
+
# Software is furnished to do so, subject to the following conditions:
|
10 |
+
#
|
11 |
+
# The above copyright notice and this permission notice shall be included in
|
12 |
+
# all copies or substantial portions of the Software.
|
13 |
+
#
|
14 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
17 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
20 |
+
# DEALINGS IN THE SOFTWARE.
|
21 |
+
import torch
|
22 |
+
from torch import nn
|
23 |
+
from common import Encoder, LengthRegulator, ConvAttention
|
24 |
+
from common import Invertible1x1ConvLUS, Invertible1x1Conv
|
25 |
+
from common import AffineTransformationLayer, LinearNorm, ExponentialClass
|
26 |
+
from common import get_mask_from_lengths
|
27 |
+
from attribute_prediction_model import get_attribute_prediction_model
|
28 |
+
from alignment import mas_width1 as mas
|
29 |
+
|
30 |
+
|
31 |
+
class FlowStep(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
n_mel_channels,
|
35 |
+
n_context_dim,
|
36 |
+
n_layers,
|
37 |
+
affine_model="simple_conv",
|
38 |
+
scaling_fn="exp",
|
39 |
+
matrix_decomposition="",
|
40 |
+
affine_activation="softplus",
|
41 |
+
use_partial_padding=False,
|
42 |
+
cache_inverse=False,
|
43 |
+
):
|
44 |
+
super(FlowStep, self).__init__()
|
45 |
+
if matrix_decomposition == "LUS":
|
46 |
+
self.invtbl_conv = Invertible1x1ConvLUS(
|
47 |
+
n_mel_channels, cache_inverse=cache_inverse
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
self.invtbl_conv = Invertible1x1Conv(
|
51 |
+
n_mel_channels, cache_inverse=cache_inverse
|
52 |
+
)
|
53 |
+
|
54 |
+
self.affine_tfn = AffineTransformationLayer(
|
55 |
+
n_mel_channels,
|
56 |
+
n_context_dim,
|
57 |
+
n_layers,
|
58 |
+
affine_model=affine_model,
|
59 |
+
scaling_fn=scaling_fn,
|
60 |
+
affine_activation=affine_activation,
|
61 |
+
use_partial_padding=use_partial_padding,
|
62 |
+
)
|
63 |
+
|
64 |
+
def enable_inverse_cache(self):
|
65 |
+
self.invtbl_conv.cache_inverse = True
|
66 |
+
|
67 |
+
def forward(self, z, context, inverse=False, seq_lens=None):
|
68 |
+
if inverse: # for inference z-> mel
|
69 |
+
z = self.affine_tfn(z, context, inverse, seq_lens=seq_lens)
|
70 |
+
z = self.invtbl_conv(z, inverse)
|
71 |
+
return z
|
72 |
+
else: # training mel->z
|
73 |
+
z, log_det_W = self.invtbl_conv(z)
|
74 |
+
z, log_s = self.affine_tfn(z, context, seq_lens=seq_lens)
|
75 |
+
return z, log_det_W, log_s
|
76 |
+
|
77 |
+
|
78 |
+
class RADTTS(torch.nn.Module):
|
79 |
+
def __init__(
|
80 |
+
self,
|
81 |
+
n_speakers,
|
82 |
+
n_speaker_dim,
|
83 |
+
n_text,
|
84 |
+
n_text_dim,
|
85 |
+
n_flows,
|
86 |
+
n_conv_layers_per_step,
|
87 |
+
n_mel_channels,
|
88 |
+
n_hidden,
|
89 |
+
mel_encoder_n_hidden,
|
90 |
+
dummy_speaker_embedding,
|
91 |
+
n_early_size,
|
92 |
+
n_early_every,
|
93 |
+
n_group_size,
|
94 |
+
affine_model,
|
95 |
+
dur_model_config,
|
96 |
+
f0_model_config,
|
97 |
+
energy_model_config,
|
98 |
+
v_model_config=None,
|
99 |
+
include_modules="dec",
|
100 |
+
scaling_fn="exp",
|
101 |
+
matrix_decomposition="",
|
102 |
+
learn_alignments=False,
|
103 |
+
affine_activation="softplus",
|
104 |
+
attn_use_CTC=True,
|
105 |
+
use_speaker_emb_for_alignment=False,
|
106 |
+
use_context_lstm=False,
|
107 |
+
context_lstm_norm=None,
|
108 |
+
text_encoder_lstm_norm=None,
|
109 |
+
n_f0_dims=0,
|
110 |
+
n_energy_avg_dims=0,
|
111 |
+
context_lstm_w_f0_and_energy=True,
|
112 |
+
use_first_order_features=False,
|
113 |
+
unvoiced_bias_activation="",
|
114 |
+
ap_pred_log_f0=False,
|
115 |
+
**kwargs,
|
116 |
+
):
|
117 |
+
super(RADTTS, self).__init__()
|
118 |
+
assert n_early_size % 2 == 0
|
119 |
+
self.do_mel_descaling = kwargs.get("do_mel_descaling", True)
|
120 |
+
self.n_mel_channels = n_mel_channels
|
121 |
+
self.n_f0_dims = n_f0_dims # >= 1 to trains with f0
|
122 |
+
self.n_energy_avg_dims = n_energy_avg_dims # >= 1 trains with energy
|
123 |
+
self.decoder_use_partial_padding = kwargs.get(
|
124 |
+
"decoder_use_partial_padding", True
|
125 |
+
)
|
126 |
+
self.n_speaker_dim = n_speaker_dim
|
127 |
+
assert self.n_speaker_dim % 2 == 0
|
128 |
+
self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim)
|
129 |
+
self.embedding = torch.nn.Embedding(n_text, n_text_dim)
|
130 |
+
self.flows = torch.nn.ModuleList()
|
131 |
+
self.encoder = Encoder(
|
132 |
+
encoder_embedding_dim=n_text_dim,
|
133 |
+
norm_fn=nn.InstanceNorm1d,
|
134 |
+
lstm_norm_fn=text_encoder_lstm_norm,
|
135 |
+
)
|
136 |
+
self.dummy_speaker_embedding = dummy_speaker_embedding
|
137 |
+
self.learn_alignments = learn_alignments
|
138 |
+
self.affine_activation = affine_activation
|
139 |
+
self.include_modules = include_modules
|
140 |
+
self.attn_use_CTC = bool(attn_use_CTC)
|
141 |
+
self.use_speaker_emb_for_alignment = use_speaker_emb_for_alignment
|
142 |
+
self.use_context_lstm = bool(use_context_lstm)
|
143 |
+
self.context_lstm_norm = context_lstm_norm
|
144 |
+
self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy
|
145 |
+
self.length_regulator = LengthRegulator()
|
146 |
+
self.use_first_order_features = bool(use_first_order_features)
|
147 |
+
self.decoder_use_unvoiced_bias = kwargs.get("decoder_use_unvoiced_bias", True)
|
148 |
+
self.ap_pred_log_f0 = ap_pred_log_f0
|
149 |
+
self.ap_use_unvoiced_bias = kwargs.get("ap_use_unvoiced_bias", True)
|
150 |
+
self.attn_straight_through_estimator = kwargs.get(
|
151 |
+
"attn_straight_through_estimator", False
|
152 |
+
)
|
153 |
+
if "atn" in include_modules or "dec" in include_modules:
|
154 |
+
if self.learn_alignments:
|
155 |
+
if self.use_speaker_emb_for_alignment:
|
156 |
+
self.attention = ConvAttention(
|
157 |
+
n_mel_channels, n_text_dim + self.n_speaker_dim
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
self.attention = ConvAttention(n_mel_channels, n_text_dim)
|
161 |
+
|
162 |
+
self.n_flows = n_flows
|
163 |
+
self.n_group_size = n_group_size
|
164 |
+
|
165 |
+
n_flowstep_cond_dims = (
|
166 |
+
self.n_speaker_dim
|
167 |
+
+ (n_text_dim + n_f0_dims + n_energy_avg_dims) * n_group_size
|
168 |
+
)
|
169 |
+
|
170 |
+
if self.use_context_lstm:
|
171 |
+
n_in_context_lstm = self.n_speaker_dim + n_text_dim * n_group_size
|
172 |
+
n_context_lstm_hidden = int(
|
173 |
+
(self.n_speaker_dim + n_text_dim * n_group_size) / 2
|
174 |
+
)
|
175 |
+
|
176 |
+
if self.context_lstm_w_f0_and_energy:
|
177 |
+
n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim
|
178 |
+
n_in_context_lstm *= n_group_size
|
179 |
+
n_in_context_lstm += self.n_speaker_dim
|
180 |
+
|
181 |
+
n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim
|
182 |
+
n_context_hidden = n_context_hidden * n_group_size / 2
|
183 |
+
n_context_hidden = self.n_speaker_dim + n_context_hidden
|
184 |
+
n_context_hidden = int(n_context_hidden)
|
185 |
+
|
186 |
+
n_flowstep_cond_dims = (
|
187 |
+
self.n_speaker_dim + n_text_dim * n_group_size
|
188 |
+
)
|
189 |
+
|
190 |
+
self.context_lstm = torch.nn.LSTM(
|
191 |
+
input_size=n_in_context_lstm,
|
192 |
+
hidden_size=n_context_lstm_hidden,
|
193 |
+
num_layers=1,
|
194 |
+
batch_first=True,
|
195 |
+
bidirectional=True,
|
196 |
+
)
|
197 |
+
|
198 |
+
if context_lstm_norm is not None:
|
199 |
+
if "spectral" in context_lstm_norm:
|
200 |
+
print("Applying spectral norm to context encoder LSTM")
|
201 |
+
lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
|
202 |
+
elif "weight" in context_lstm_norm:
|
203 |
+
print("Applying weight norm to context encoder LSTM")
|
204 |
+
lstm_norm_fn_pntr = torch.nn.utils.weight_norm
|
205 |
+
|
206 |
+
self.context_lstm = lstm_norm_fn_pntr(
|
207 |
+
self.context_lstm, "weight_hh_l0"
|
208 |
+
)
|
209 |
+
self.context_lstm = lstm_norm_fn_pntr(
|
210 |
+
self.context_lstm, "weight_hh_l0_reverse"
|
211 |
+
)
|
212 |
+
|
213 |
+
if self.n_group_size > 1:
|
214 |
+
self.unfold_params = {
|
215 |
+
"kernel_size": (n_group_size, 1),
|
216 |
+
"stride": n_group_size,
|
217 |
+
"padding": 0,
|
218 |
+
"dilation": 1,
|
219 |
+
}
|
220 |
+
self.unfold = nn.Unfold(**self.unfold_params)
|
221 |
+
|
222 |
+
self.exit_steps = []
|
223 |
+
self.n_early_size = n_early_size
|
224 |
+
n_mel_channels = n_mel_channels * n_group_size
|
225 |
+
|
226 |
+
for i in range(self.n_flows):
|
227 |
+
if i > 0 and i % n_early_every == 0: # early exitting
|
228 |
+
n_mel_channels -= self.n_early_size
|
229 |
+
self.exit_steps.append(i)
|
230 |
+
|
231 |
+
self.flows.append(
|
232 |
+
FlowStep(
|
233 |
+
n_mel_channels,
|
234 |
+
n_flowstep_cond_dims,
|
235 |
+
n_conv_layers_per_step,
|
236 |
+
affine_model,
|
237 |
+
scaling_fn,
|
238 |
+
matrix_decomposition,
|
239 |
+
affine_activation=affine_activation,
|
240 |
+
use_partial_padding=self.decoder_use_partial_padding,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
|
244 |
+
if "dpm" in include_modules:
|
245 |
+
dur_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
246 |
+
self.dur_pred_layer = get_attribute_prediction_model(dur_model_config)
|
247 |
+
|
248 |
+
self.use_unvoiced_bias = False
|
249 |
+
self.use_vpred_module = False
|
250 |
+
self.ap_use_voiced_embeddings = kwargs.get("ap_use_voiced_embeddings", True)
|
251 |
+
|
252 |
+
if self.decoder_use_unvoiced_bias or self.ap_use_unvoiced_bias:
|
253 |
+
assert unvoiced_bias_activation in {"relu", "exp"}
|
254 |
+
self.use_unvoiced_bias = True
|
255 |
+
if unvoiced_bias_activation == "relu":
|
256 |
+
unvbias_nonlin = nn.ReLU()
|
257 |
+
elif unvoiced_bias_activation == "exp":
|
258 |
+
unvbias_nonlin = ExponentialClass()
|
259 |
+
else:
|
260 |
+
exit(1) # we won't reach here anyway due to the assertion
|
261 |
+
self.unvoiced_bias_module = nn.Sequential(
|
262 |
+
LinearNorm(n_text_dim, 1), unvbias_nonlin
|
263 |
+
)
|
264 |
+
|
265 |
+
# all situations in which the vpred module is necessary
|
266 |
+
if (
|
267 |
+
self.ap_use_voiced_embeddings
|
268 |
+
or self.use_unvoiced_bias
|
269 |
+
or "vpred" in include_modules
|
270 |
+
):
|
271 |
+
self.use_vpred_module = True
|
272 |
+
|
273 |
+
if self.use_vpred_module:
|
274 |
+
v_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
275 |
+
self.v_pred_module = get_attribute_prediction_model(v_model_config)
|
276 |
+
# 4 embeddings, first two are scales, second two are biases
|
277 |
+
if self.ap_use_voiced_embeddings:
|
278 |
+
self.v_embeddings = torch.nn.Embedding(4, n_text_dim)
|
279 |
+
|
280 |
+
if "apm" in include_modules:
|
281 |
+
f0_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
282 |
+
energy_model_config["hparams"]["n_speaker_dim"] = n_speaker_dim
|
283 |
+
if self.use_first_order_features:
|
284 |
+
f0_model_config["hparams"]["n_in_dim"] = 2
|
285 |
+
energy_model_config["hparams"]["n_in_dim"] = 2
|
286 |
+
if (
|
287 |
+
"spline_flow_params" in f0_model_config["hparams"]
|
288 |
+
and f0_model_config["hparams"]["spline_flow_params"] is not None
|
289 |
+
):
|
290 |
+
f0_model_config["hparams"]["spline_flow_params"][
|
291 |
+
"n_in_channels"
|
292 |
+
] = 2
|
293 |
+
if (
|
294 |
+
"spline_flow_params" in energy_model_config["hparams"]
|
295 |
+
and energy_model_config["hparams"]["spline_flow_params"] is not None
|
296 |
+
):
|
297 |
+
energy_model_config["hparams"]["spline_flow_params"][
|
298 |
+
"n_in_channels"
|
299 |
+
] = 2
|
300 |
+
else:
|
301 |
+
if (
|
302 |
+
"spline_flow_params" in f0_model_config["hparams"]
|
303 |
+
and f0_model_config["hparams"]["spline_flow_params"] is not None
|
304 |
+
):
|
305 |
+
f0_model_config["hparams"]["spline_flow_params"][
|
306 |
+
"n_in_channels"
|
307 |
+
] = f0_model_config["hparams"]["n_in_dim"]
|
308 |
+
if (
|
309 |
+
"spline_flow_params" in energy_model_config["hparams"]
|
310 |
+
and energy_model_config["hparams"]["spline_flow_params"] is not None
|
311 |
+
):
|
312 |
+
energy_model_config["hparams"]["spline_flow_params"][
|
313 |
+
"n_in_channels"
|
314 |
+
] = energy_model_config["hparams"]["n_in_dim"]
|
315 |
+
|
316 |
+
self.f0_pred_module = get_attribute_prediction_model(f0_model_config)
|
317 |
+
self.energy_pred_module = get_attribute_prediction_model(
|
318 |
+
energy_model_config
|
319 |
+
)
|
320 |
+
|
321 |
+
def is_attribute_unconditional(self):
|
322 |
+
"""
|
323 |
+
returns true if the decoder is conditioned on neither energy nor F0
|
324 |
+
"""
|
325 |
+
return self.n_f0_dims == 0 and self.n_energy_avg_dims == 0
|
326 |
+
|
327 |
+
def encode_speaker(self, spk_ids):
|
328 |
+
spk_ids = spk_ids * 0 if self.dummy_speaker_embedding else spk_ids
|
329 |
+
spk_vecs = self.speaker_embedding(spk_ids)
|
330 |
+
return spk_vecs
|
331 |
+
|
332 |
+
def encode_text(self, text, in_lens):
|
333 |
+
# text_embeddings: b x len_text x n_text_dim
|
334 |
+
text_embeddings = self.embedding(text).transpose(1, 2)
|
335 |
+
# text_enc: b x n_text_dim x encoder_dim (512)
|
336 |
+
if in_lens is None:
|
337 |
+
text_enc = self.encoder.infer(text_embeddings).transpose(1, 2)
|
338 |
+
else:
|
339 |
+
text_enc = self.encoder(text_embeddings, in_lens).transpose(1, 2)
|
340 |
+
|
341 |
+
return text_enc, text_embeddings
|
342 |
+
|
343 |
+
def preprocess_context(
|
344 |
+
self, context, speaker_vecs, out_lens=None, f0=None, energy_avg=None
|
345 |
+
):
|
346 |
+
if self.n_group_size > 1:
|
347 |
+
# unfolding zero-padded values
|
348 |
+
context = self.unfold(context.unsqueeze(-1))
|
349 |
+
if f0 is not None:
|
350 |
+
f0 = self.unfold(f0[:, None, :, None])
|
351 |
+
if energy_avg is not None:
|
352 |
+
energy_avg = self.unfold(energy_avg[:, None, :, None])
|
353 |
+
speaker_vecs = speaker_vecs[..., None].expand(-1, -1, context.shape[2])
|
354 |
+
context_w_spkvec = torch.cat((context, speaker_vecs), 1)
|
355 |
+
|
356 |
+
if self.use_context_lstm:
|
357 |
+
if self.context_lstm_w_f0_and_energy:
|
358 |
+
if f0 is not None:
|
359 |
+
context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
|
360 |
+
|
361 |
+
if energy_avg is not None:
|
362 |
+
context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
|
363 |
+
|
364 |
+
unfolded_out_lens = (out_lens // self.n_group_size).long().cpu()
|
365 |
+
unfolded_out_lens_packed = nn.utils.rnn.pack_padded_sequence(
|
366 |
+
context_w_spkvec.transpose(1, 2),
|
367 |
+
unfolded_out_lens,
|
368 |
+
batch_first=True,
|
369 |
+
enforce_sorted=False,
|
370 |
+
)
|
371 |
+
self.context_lstm.flatten_parameters()
|
372 |
+
context_lstm_packed_output, _ = self.context_lstm(unfolded_out_lens_packed)
|
373 |
+
context_lstm_padded_output, _ = nn.utils.rnn.pad_packed_sequence(
|
374 |
+
context_lstm_packed_output, batch_first=True
|
375 |
+
)
|
376 |
+
context_w_spkvec = context_lstm_padded_output.transpose(1, 2)
|
377 |
+
|
378 |
+
if not self.context_lstm_w_f0_and_energy:
|
379 |
+
if f0 is not None:
|
380 |
+
context_w_spkvec = torch.cat((context_w_spkvec, f0), 1)
|
381 |
+
|
382 |
+
if energy_avg is not None:
|
383 |
+
context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
|
384 |
+
|
385 |
+
return context_w_spkvec
|
386 |
+
|
387 |
+
def enable_inverse_cache(self):
|
388 |
+
for flow_step in self.flows:
|
389 |
+
flow_step.enable_inverse_cache()
|
390 |
+
|
391 |
+
def fold(self, mel):
|
392 |
+
"""Inverse of the self.unfold(mel.unsqueeze(-1)) operation used for the
|
393 |
+
grouping or "squeeze" operation on input
|
394 |
+
|
395 |
+
Args:
|
396 |
+
mel: B x C x T tensor of temporal data
|
397 |
+
"""
|
398 |
+
mel = nn.functional.fold(
|
399 |
+
mel, output_size=(mel.shape[2] * self.n_group_size, 1), **self.unfold_params
|
400 |
+
).squeeze(-1)
|
401 |
+
return mel
|
402 |
+
|
403 |
+
def binarize_attention(self, attn, in_lens, out_lens):
|
404 |
+
"""For training purposes only. Binarizes attention with MAS. These will
|
405 |
+
no longer recieve a gradient
|
406 |
+
Args:
|
407 |
+
attn: B x 1 x max_mel_len x max_text_len
|
408 |
+
"""
|
409 |
+
b_size = attn.shape[0]
|
410 |
+
with torch.no_grad():
|
411 |
+
attn_cpu = attn.data.cpu().numpy()
|
412 |
+
attn_out = torch.zeros_like(attn)
|
413 |
+
for ind in range(b_size):
|
414 |
+
hard_attn = mas(attn_cpu[ind, 0, : out_lens[ind], : in_lens[ind]])
|
415 |
+
attn_out[ind, 0, : out_lens[ind], : in_lens[ind]] = torch.tensor(
|
416 |
+
hard_attn, device=attn.get_device()
|
417 |
+
)
|
418 |
+
return attn_out
|
419 |
+
|
420 |
+
def get_first_order_features(self, feats, out_lens, dilation=1):
|
421 |
+
"""
|
422 |
+
feats: b x max_length
|
423 |
+
out_lens: b-dim
|
424 |
+
"""
|
425 |
+
# add an extra column
|
426 |
+
feats_extended_R = torch.cat(
|
427 |
+
(feats, torch.zeros_like(feats[:, 0:dilation])), dim=1
|
428 |
+
)
|
429 |
+
feats_extended_L = torch.cat(
|
430 |
+
(torch.zeros_like(feats[:, 0:dilation]), feats), dim=1
|
431 |
+
)
|
432 |
+
dfeats_R = feats_extended_R[:, dilation:] - feats
|
433 |
+
dfeats_L = feats - feats_extended_L[:, 0:-dilation]
|
434 |
+
|
435 |
+
return (dfeats_R + dfeats_L) * 0.5
|
436 |
+
|
437 |
+
def apply_voice_mask_to_text(self, text_enc, voiced_mask):
|
438 |
+
"""
|
439 |
+
text_enc: b x C x N
|
440 |
+
voiced_mask: b x N
|
441 |
+
"""
|
442 |
+
voiced_mask = voiced_mask.unsqueeze(1)
|
443 |
+
voiced_embedding_s = self.v_embeddings.weight[0:1, :, None]
|
444 |
+
unvoiced_embedding_s = self.v_embeddings.weight[1:2, :, None]
|
445 |
+
voiced_embedding_b = self.v_embeddings.weight[2:3, :, None]
|
446 |
+
unvoiced_embedding_b = self.v_embeddings.weight[3:4, :, None]
|
447 |
+
scale = torch.sigmoid(
|
448 |
+
voiced_embedding_s * voiced_mask + unvoiced_embedding_s * (1 - voiced_mask)
|
449 |
+
)
|
450 |
+
bias = 0.1 * torch.tanh(
|
451 |
+
voiced_embedding_b * voiced_mask + unvoiced_embedding_b * (1 - voiced_mask)
|
452 |
+
)
|
453 |
+
return text_enc * scale + bias
|
454 |
+
|
455 |
+
def forward(
|
456 |
+
self,
|
457 |
+
mel,
|
458 |
+
speaker_ids,
|
459 |
+
text,
|
460 |
+
in_lens,
|
461 |
+
out_lens,
|
462 |
+
binarize_attention=False,
|
463 |
+
attn_prior=None,
|
464 |
+
f0=None,
|
465 |
+
energy_avg=None,
|
466 |
+
voiced_mask=None,
|
467 |
+
p_voiced=None,
|
468 |
+
):
|
469 |
+
speaker_vecs = self.encode_speaker(speaker_ids)
|
470 |
+
text_enc, text_embeddings = self.encode_text(text, in_lens)
|
471 |
+
|
472 |
+
log_s_list, log_det_W_list, z_mel = [], [], []
|
473 |
+
attn = None
|
474 |
+
attn_soft = None
|
475 |
+
attn_hard = None
|
476 |
+
if "atn" in self.include_modules or "dec" in self.include_modules:
|
477 |
+
# make sure to do the alignments before folding
|
478 |
+
attn_mask = get_mask_from_lengths(in_lens)[..., None] == 0
|
479 |
+
|
480 |
+
text_embeddings_for_attn = text_embeddings
|
481 |
+
if self.use_speaker_emb_for_alignment:
|
482 |
+
speaker_vecs_expd = speaker_vecs[:, :, None].expand(
|
483 |
+
-1, -1, text_embeddings.shape[2]
|
484 |
+
)
|
485 |
+
text_embeddings_for_attn = torch.cat(
|
486 |
+
(text_embeddings_for_attn, speaker_vecs_expd.detach()), 1
|
487 |
+
)
|
488 |
+
|
489 |
+
# attn_mask shld be 1 for unsd t-steps in text_enc_w_spkvec tensor
|
490 |
+
attn_soft, attn_logprob = self.attention(
|
491 |
+
mel,
|
492 |
+
text_embeddings_for_attn,
|
493 |
+
out_lens,
|
494 |
+
attn_mask,
|
495 |
+
key_lens=in_lens,
|
496 |
+
attn_prior=attn_prior,
|
497 |
+
)
|
498 |
+
|
499 |
+
if binarize_attention:
|
500 |
+
attn = self.binarize_attention(attn_soft, in_lens, out_lens)
|
501 |
+
attn_hard = attn
|
502 |
+
if self.attn_straight_through_estimator:
|
503 |
+
attn_hard = attn_soft + (attn_hard - attn_soft).detach()
|
504 |
+
else:
|
505 |
+
attn = attn_soft
|
506 |
+
|
507 |
+
context = torch.bmm(text_enc, attn.squeeze(1).transpose(1, 2))
|
508 |
+
|
509 |
+
f0_bias = 0
|
510 |
+
# unvoiced bias forward pass
|
511 |
+
if self.use_unvoiced_bias:
|
512 |
+
f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1))
|
513 |
+
f0_bias = -f0_bias[..., 0]
|
514 |
+
f0_bias = f0_bias * (~voiced_mask.bool()).float()
|
515 |
+
|
516 |
+
# mel decoder forward pass
|
517 |
+
if "dec" in self.include_modules:
|
518 |
+
if self.n_group_size > 1:
|
519 |
+
# might truncate some frames at the end, but that's ok
|
520 |
+
# sometimes referred to as the "squeeeze" operation
|
521 |
+
# invert this by calling self.fold(mel_or_z)
|
522 |
+
mel = self.unfold(mel.unsqueeze(-1))
|
523 |
+
z_out = []
|
524 |
+
# where context is folded
|
525 |
+
# mask f0 in case values are interpolated
|
526 |
+
|
527 |
+
if f0 is None:
|
528 |
+
f0_aug = None
|
529 |
+
else:
|
530 |
+
if self.decoder_use_unvoiced_bias:
|
531 |
+
f0_aug = f0 * voiced_mask + f0_bias
|
532 |
+
else:
|
533 |
+
f0_aug = f0 * voiced_mask
|
534 |
+
|
535 |
+
context_w_spkvec = self.preprocess_context(
|
536 |
+
context, speaker_vecs, out_lens, f0_aug, energy_avg
|
537 |
+
)
|
538 |
+
|
539 |
+
log_s_list, log_det_W_list, z_out = [], [], []
|
540 |
+
unfolded_seq_lens = out_lens // self.n_group_size
|
541 |
+
for i, flow_step in enumerate(self.flows):
|
542 |
+
if i in self.exit_steps:
|
543 |
+
z = mel[:, : self.n_early_size]
|
544 |
+
z_out.append(z)
|
545 |
+
mel = mel[:, self.n_early_size :]
|
546 |
+
mel, log_det_W, log_s = flow_step(
|
547 |
+
mel, context_w_spkvec, seq_lens=unfolded_seq_lens
|
548 |
+
)
|
549 |
+
log_s_list.append(log_s)
|
550 |
+
log_det_W_list.append(log_det_W)
|
551 |
+
|
552 |
+
z_out.append(mel)
|
553 |
+
z_mel = torch.cat(z_out, 1)
|
554 |
+
|
555 |
+
# duration predictor forward pass
|
556 |
+
duration_model_outputs = None
|
557 |
+
if "dpm" in self.include_modules:
|
558 |
+
if attn_hard is None:
|
559 |
+
attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
|
560 |
+
|
561 |
+
# convert hard attention to durations
|
562 |
+
attn_hard_reduced = attn_hard.sum(2)[:, 0, :]
|
563 |
+
duration_model_outputs = self.dur_pred_layer(
|
564 |
+
torch.detach(text_enc),
|
565 |
+
torch.detach(speaker_vecs),
|
566 |
+
torch.detach(attn_hard_reduced.float()),
|
567 |
+
in_lens,
|
568 |
+
)
|
569 |
+
|
570 |
+
# f0, energy, vpred predictors forward pass
|
571 |
+
f0_model_outputs = None
|
572 |
+
energy_model_outputs = None
|
573 |
+
vpred_model_outputs = None
|
574 |
+
if "apm" in self.include_modules:
|
575 |
+
if attn_hard is None:
|
576 |
+
attn_hard = self.binarize_attention(attn_soft, in_lens, out_lens)
|
577 |
+
|
578 |
+
# convert hard attention to durations
|
579 |
+
if binarize_attention:
|
580 |
+
text_enc_time_expanded = context.clone()
|
581 |
+
else:
|
582 |
+
text_enc_time_expanded = torch.bmm(
|
583 |
+
text_enc, attn_hard.squeeze(1).transpose(1, 2)
|
584 |
+
)
|
585 |
+
|
586 |
+
if self.use_vpred_module:
|
587 |
+
# unvoiced bias requires voiced mask prediction
|
588 |
+
vpred_model_outputs = self.v_pred_module(
|
589 |
+
torch.detach(text_enc_time_expanded),
|
590 |
+
torch.detach(speaker_vecs),
|
591 |
+
torch.detach(voiced_mask),
|
592 |
+
out_lens,
|
593 |
+
)
|
594 |
+
|
595 |
+
# affine transform context using voiced mask
|
596 |
+
if self.ap_use_voiced_embeddings:
|
597 |
+
text_enc_time_expanded = self.apply_voice_mask_to_text(
|
598 |
+
text_enc_time_expanded, voiced_mask
|
599 |
+
)
|
600 |
+
|
601 |
+
# whether to use the unvoiced bias in the attribute predictor
|
602 |
+
# circumvent in-place modification
|
603 |
+
f0_target = f0.clone()
|
604 |
+
if self.ap_use_unvoiced_bias:
|
605 |
+
f0_target = torch.detach(f0_target * voiced_mask + f0_bias)
|
606 |
+
else:
|
607 |
+
f0_target = torch.detach(f0_target)
|
608 |
+
|
609 |
+
# fit to log f0 in f0 predictor
|
610 |
+
f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()])
|
611 |
+
f0_target = f0_target / 6 # scale to ~ [0, 1] in log space
|
612 |
+
energy_avg = energy_avg * 2 - 1 # scale to ~ [-1, 1]
|
613 |
+
|
614 |
+
if self.use_first_order_features:
|
615 |
+
df0 = self.get_first_order_features(f0_target, out_lens)
|
616 |
+
denergy_avg = self.get_first_order_features(energy_avg, out_lens)
|
617 |
+
|
618 |
+
f0_voiced = torch.cat((f0_target[:, None], df0[:, None]), dim=1)
|
619 |
+
energy_avg = torch.cat(
|
620 |
+
(energy_avg[:, None], denergy_avg[:, None]), dim=1
|
621 |
+
)
|
622 |
+
|
623 |
+
f0_voiced = f0_voiced * 3 # scale to ~ 1 std
|
624 |
+
energy_avg = energy_avg * 3 # scale to ~ 1 std
|
625 |
+
else:
|
626 |
+
f0_voiced = f0_target * 2 # scale to ~ 1 std
|
627 |
+
energy_avg = energy_avg * 1.4 # scale to ~ 1 std
|
628 |
+
|
629 |
+
f0_model_outputs = self.f0_pred_module(
|
630 |
+
text_enc_time_expanded, torch.detach(speaker_vecs), f0_voiced, out_lens
|
631 |
+
)
|
632 |
+
|
633 |
+
energy_model_outputs = self.energy_pred_module(
|
634 |
+
text_enc_time_expanded, torch.detach(speaker_vecs), energy_avg, out_lens
|
635 |
+
)
|
636 |
+
|
637 |
+
outputs = {
|
638 |
+
"z_mel": z_mel,
|
639 |
+
"log_det_W_list": log_det_W_list,
|
640 |
+
"log_s_list": log_s_list,
|
641 |
+
"duration_model_outputs": duration_model_outputs,
|
642 |
+
"f0_model_outputs": f0_model_outputs,
|
643 |
+
"energy_model_outputs": energy_model_outputs,
|
644 |
+
"vpred_model_outputs": vpred_model_outputs,
|
645 |
+
"attn_soft": attn_soft,
|
646 |
+
"attn": attn,
|
647 |
+
"text_embeddings": text_embeddings,
|
648 |
+
"attn_logprob": attn_logprob,
|
649 |
+
}
|
650 |
+
|
651 |
+
return outputs
|
652 |
+
|
653 |
+
def infer(
|
654 |
+
self,
|
655 |
+
speaker_id,
|
656 |
+
text,
|
657 |
+
sigma,
|
658 |
+
sigma_dur=0.8,
|
659 |
+
sigma_f0=0.8,
|
660 |
+
sigma_energy=0.8,
|
661 |
+
token_dur_scaling=1.0,
|
662 |
+
token_duration_max=100,
|
663 |
+
speaker_id_text=None,
|
664 |
+
speaker_id_attributes=None,
|
665 |
+
dur=None,
|
666 |
+
f0=None,
|
667 |
+
energy_avg=None,
|
668 |
+
voiced_mask=None,
|
669 |
+
f0_mean=0.0,
|
670 |
+
f0_std=0.0,
|
671 |
+
energy_mean=0.0,
|
672 |
+
energy_std=0.0,
|
673 |
+
use_cuda=False,
|
674 |
+
):
|
675 |
+
batch_size = text.shape[0]
|
676 |
+
n_tokens = text.shape[1]
|
677 |
+
spk_vec = self.encode_speaker(speaker_id)
|
678 |
+
spk_vec_text, spk_vec_attributes = spk_vec, spk_vec
|
679 |
+
if speaker_id_text is not None:
|
680 |
+
spk_vec_text = self.encode_speaker(speaker_id_text)
|
681 |
+
if speaker_id_attributes is not None:
|
682 |
+
spk_vec_attributes = self.encode_speaker(speaker_id_attributes)
|
683 |
+
|
684 |
+
txt_enc, txt_emb = self.encode_text(text, None)
|
685 |
+
|
686 |
+
if dur is None:
|
687 |
+
# get token durations
|
688 |
+
if use_cuda:
|
689 |
+
z_dur = torch.cuda.FloatTensor(batch_size, 1, n_tokens)
|
690 |
+
else:
|
691 |
+
z_dur = torch.FloatTensor(batch_size, 1, n_tokens)
|
692 |
+
|
693 |
+
z_dur = z_dur.normal_() * sigma_dur
|
694 |
+
|
695 |
+
dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
|
696 |
+
if dur.shape[-1] < txt_enc.shape[-1]:
|
697 |
+
to_pad = txt_enc.shape[-1] - dur.shape[2]
|
698 |
+
pad_fn = nn.ReplicationPad1d((0, to_pad))
|
699 |
+
dur = pad_fn(dur)
|
700 |
+
dur = dur[:, 0]
|
701 |
+
dur = dur.clamp(0, token_duration_max)
|
702 |
+
dur = dur * token_dur_scaling if token_dur_scaling > 0 else dur
|
703 |
+
dur = (dur + 0.5).floor().int()
|
704 |
+
|
705 |
+
out_lens = dur.sum(1).long().cpu() if dur.shape[0] != 1 else [dur.sum(1)]
|
706 |
+
max_n_frames = max(out_lens)
|
707 |
+
|
708 |
+
out_lens = torch.LongTensor(out_lens).to(txt_enc.device)
|
709 |
+
|
710 |
+
# get attributes f0, energy, vpred, etc)
|
711 |
+
txt_enc_time_expanded = self.length_regulator(
|
712 |
+
txt_enc.transpose(1, 2), dur
|
713 |
+
).transpose(1, 2)
|
714 |
+
|
715 |
+
if not self.is_attribute_unconditional():
|
716 |
+
# if explicitly modeling attributes
|
717 |
+
if voiced_mask is None:
|
718 |
+
if self.use_vpred_module:
|
719 |
+
# get logits
|
720 |
+
voiced_mask = self.v_pred_module.infer(
|
721 |
+
None, txt_enc_time_expanded, spk_vec_attributes
|
722 |
+
)
|
723 |
+
voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5
|
724 |
+
voiced_mask = voiced_mask.float()
|
725 |
+
|
726 |
+
ap_txt_enc_time_expanded = txt_enc_time_expanded
|
727 |
+
# voice mask augmentation only used for attribute prediction
|
728 |
+
if self.ap_use_voiced_embeddings:
|
729 |
+
ap_txt_enc_time_expanded = self.apply_voice_mask_to_text(
|
730 |
+
txt_enc_time_expanded, voiced_mask
|
731 |
+
)
|
732 |
+
|
733 |
+
f0_bias = 0
|
734 |
+
# unvoiced bias forward pass
|
735 |
+
if self.use_unvoiced_bias:
|
736 |
+
f0_bias = self.unvoiced_bias_module(
|
737 |
+
txt_enc_time_expanded.permute(0, 2, 1)
|
738 |
+
)
|
739 |
+
f0_bias = -f0_bias[..., 0]
|
740 |
+
f0_bias = f0_bias * (~voiced_mask.bool()).float()
|
741 |
+
|
742 |
+
if f0 is None:
|
743 |
+
n_f0_feature_channels = 2 if self.use_first_order_features else 1
|
744 |
+
|
745 |
+
if use_cuda:
|
746 |
+
z_f0 = (
|
747 |
+
torch.cuda.FloatTensor(
|
748 |
+
batch_size, n_f0_feature_channels, max_n_frames
|
749 |
+
).normal_()
|
750 |
+
* sigma_f0
|
751 |
+
)
|
752 |
+
else:
|
753 |
+
z_f0 = (
|
754 |
+
torch.FloatTensor(
|
755 |
+
batch_size, n_f0_feature_channels, max_n_frames
|
756 |
+
).normal_()
|
757 |
+
* sigma_f0
|
758 |
+
                    )

                f0 = self.infer_f0(
                    z_f0,
                    ap_txt_enc_time_expanded,
                    spk_vec_attributes,
                    voiced_mask,
                    out_lens,
                )[:, 0]

            if f0_mean > 0.0:
                vmask_bool = voiced_mask.bool()
                f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std()
                f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma
                f0_std = f0_std if f0_std > 0 else f0_sigma
                f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean

            if energy_avg is None:
                n_energy_feature_channels = 2 if self.use_first_order_features else 1
                if use_cuda:
                    z_energy_avg = (
                        torch.cuda.FloatTensor(
                            batch_size, n_energy_feature_channels, max_n_frames
                        ).normal_()
                        * sigma_energy
                    )
                else:
                    z_energy_avg = (
                        torch.FloatTensor(
                            batch_size, n_energy_feature_channels, max_n_frames
                        ).normal_()
                        * sigma_energy
                    )
                energy_avg = self.infer_energy(
                    z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
                )[:, 0]

            # replication pad, because ungrouping with different group sizes
            # may lead to mismatched lengths
            if energy_avg.shape[1] < out_lens[0]:
                to_pad = out_lens[0] - energy_avg.shape[1]
                pad_fn = nn.ReplicationPad1d((0, to_pad))
                f0 = pad_fn(f0[None])[0]
                energy_avg = pad_fn(energy_avg[None])[0]
            if f0.shape[1] < out_lens[0]:
                to_pad = out_lens[0] - f0.shape[1]
                pad_fn = nn.ReplicationPad1d((0, to_pad))
                f0 = pad_fn(f0[None])[0]

            if self.decoder_use_unvoiced_bias:
                context_w_spkvec = self.preprocess_context(
                    txt_enc_time_expanded,
                    spk_vec,
                    out_lens,
                    f0 * voiced_mask + f0_bias,
                    energy_avg,
                )
            else:
                context_w_spkvec = self.preprocess_context(
                    txt_enc_time_expanded,
                    spk_vec,
                    out_lens,
                    f0 * voiced_mask,
                    energy_avg,
                )
        else:
            context_w_spkvec = self.preprocess_context(
                txt_enc_time_expanded, spk_vec, out_lens, None, None
            )

        if use_cuda:
            residual = torch.cuda.FloatTensor(
                batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
            )
        else:
            residual = torch.FloatTensor(
                batch_size, 80 * self.n_group_size, max_n_frames // self.n_group_size
            )

        residual = residual.normal_() * sigma

        # map from z sample to data
        exit_steps_stack = self.exit_steps.copy()
        mel = residual[:, len(exit_steps_stack) * self.n_early_size :]
        remaining_residual = residual[:, : len(exit_steps_stack) * self.n_early_size]
        unfolded_seq_lens = out_lens // self.n_group_size
        for i, flow_step in enumerate(reversed(self.flows)):
            curr_step = len(self.flows) - i - 1
            mel = flow_step(
                mel, context_w_spkvec, inverse=True, seq_lens=unfolded_seq_lens
            )
            if len(exit_steps_stack) > 0 and curr_step == exit_steps_stack[-1]:
                # concatenate the next chunk of z
                exit_steps_stack.pop()
                residual_to_add = remaining_residual[
                    :, len(exit_steps_stack) * self.n_early_size :
                ]
                remaining_residual = remaining_residual[
                    :, : len(exit_steps_stack) * self.n_early_size
                ]
                mel = torch.cat((residual_to_add, mel), 1)

        if self.n_group_size > 1:
            mel = self.fold(mel)
        if self.do_mel_descaling:
            mel = mel * 2 - 5.5

        return {
            "mel": mel,
            "dur": dur,
            "f0": f0,
            "energy_avg": energy_avg,
            "voiced_mask": voiced_mask,
        }

    def infer_f0(
        self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None
    ):
        f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens)

        if voiced_mask is not None and len(voiced_mask.shape) == 2:
            voiced_mask = voiced_mask[:, None]

        # constants
        if self.ap_pred_log_f0:
            if self.use_first_order_features:
                f0 = f0[:, 0:1, :] / 3
            else:
                f0 = f0 / 2
            f0 = f0 * 6
        else:
            f0 = f0 / 6
            f0 = f0 / 640

        if voiced_mask is None:
            voiced_mask = f0 > 0.0
        else:
            voiced_mask = voiced_mask.bool()

        # due to grouping, f0 might be 1 frame short
        voiced_mask = voiced_mask[:, :, : f0.shape[-1]]
        if self.ap_pred_log_f0:
            # if variable is set, decoder sees linear f0
            # mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool()
            f0[voiced_mask] = torch.exp(f0[voiced_mask])
            f0[~voiced_mask] = 0.0
        return f0

    def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens):
        energy = self.energy_pred_module.infer(
            residual, txt_enc_time_expanded, spk_vec, lens
        )

        # magic constants
        if self.use_first_order_features:
            energy = energy / 3
        else:
            energy = energy / 1.4
        energy = (energy + 1) / 2
        return energy

    def remove_norms(self):
        """Removes spectral and weightnorms from model. Call before inference"""
        for name, module in self.named_modules():
            try:
                nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
                print("Removed spectral norm from {}".format(name))
            except:
                pass
            try:
                nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
                print("Removed spectral norm from {}".format(name))
            except:
                pass
            try:
                nn.utils.remove_weight_norm(module)
                print("Removed wnorm from {}".format(name))
            except:
                pass
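The f0 branch above standardizes the sampled pitch contour over voiced frames and then maps it onto the caller-supplied f0_mean/f0_std. A standalone sketch of that re-scaling with made-up target values (plain torch, nothing model-specific):

import torch

f0 = torch.tensor([0.0, 180.0, 200.0, 220.0, 0.0])  # 0.0 marks unvoiced frames
voiced = f0 > 0
mu, sigma = f0[voiced].mean(), f0[voiced].std()
f0[voiced] = (f0[voiced] - mu) / sigma    # standardize voiced frames
f0[voiced] = f0[voiced] * 15.0 + 190.0    # hypothetical target: std 15 Hz, mean 190 Hz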
requirements-dev.txt
ADDED
@@ -0,0 +1 @@
ruff
requirements.txt
ADDED
@@ -0,0 +1,15 @@
huggingface_hub

gradio==5.18.0

torch
torchaudio
scipy
numba
lmdb
librosa

unidecode
inflect

git+https://github.com/langtech-bsc/vocos.git@matcha
splines.py
ADDED
@@ -0,0 +1,326 @@
# Original Source:
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
# Modifications made to jacobian computation by Yurong You and Kevin Shih
# Original License Text:
#########################################################################

# The MIT License (MIT)
# Copyright (c) 2020, nicolas deutschmann

# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:

# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


import torch
import torch.nn.functional as F

third_dimension_softmax = torch.nn.Softmax(dim=2)


def piecewise_linear_transform(
    x, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """Apply an element-wise piecewise-linear transformation to some variables

    Parameters
    ----------
    x : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension while k is the
        dimension of the variable space. This variable spans the k-dimensional unit
        hypercube

    q_tilde: torch.Tensor
        is a tensor with shape (N,k,b) where b is the number of bins.
        This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
        i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
        Normalization is imposed in this function using softmax.

    compute_jacobian : bool, optional
        determines whether the jacobian should be computed; if False, None is returned

    Returns
    -------
    tuple of torch.Tensor
        pair `(y, j)`.
        - `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
    """
    logj = None

    # TODO bottom-up assessment of handling the differentiability of variables
    # Compute the bin width w
    N, k, b = q_tilde.shape
    Nx, kx = x.shape
    assert N == Nx and k == kx, "Shape mismatch"

    w = 1.0 / b

    # Compute normalized bin heights with softmax function on bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)
    # x is in the mx-th bin: x \in [0,1],
    # mx \in [[0,b-1]], so we clamp away the case x == 1
    mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
    # Need special error handling because trying to index with mx
    # if it contains nans will lock the GPU. (device-side assert triggered)
    if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
        raise Exception("NaN detected in PWLinear bin indexing")

    # We compute the output variable in-place
    out = x - mx * w  # alpha (element of [0., w]), the position of x in its bin

    # Multiply by the slope
    # q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k) with entries that are a b-index
    # gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value
    # i.e. we say slope[i, j] = q[i, j, mx[i, j]]
    slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
    out = out * slopes
    # The jacobian is the product of the slopes in all dimensions

    # Compute the integral over the left-bins.
    # 1. Compute all integrals: cumulative sum of bin height * bin weight.
    # We want that index i contains the cumsum *strictly to the left* so we shift by 1
    # leaving the first entry null, which is achieved with a roll and assignment
    q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # 2. Access the correct index to get the left integral of each point and add it to our transformation
    out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)

    # Regularization: points must be strictly within the unit hypercube
    # Use the dtype information from pytorch
    eps = torch.finfo(out.dtype).eps
    out = out.clamp(min=eps, max=1.0 - eps)
    oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
    if outlier_passthru:
        out = out * (1 - oob_mask) + x * oob_mask
        slopes = slopes * (1 - oob_mask) + oob_mask

    if compute_jacobian:
        # logj = torch.log(torch.prod(slopes.float(), 1))
        logj = torch.sum(torch.log(slopes), 1)
    del slopes

    return out, logj


def piecewise_linear_inverse_transform(
    y, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """
    Apply inverse of an element-wise piecewise-linear transformation to some
    variables

    Parameters
    ----------
    y : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension while k is the
        dimension of the variable space. This variable spans the k-dimensional unit
        hypercube

    q_tilde: torch.Tensor
        is a tensor with shape (N,k,b) where b is the number of bins.
        This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
        i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
        Normalization is imposed in this function using softmax.

    compute_jacobian : bool, optional
        determines whether the jacobian should be computed; if False, None is returned

    Returns
    -------
    tuple of torch.Tensor
        pair `(x, j)`.
        - `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `j` is the jacobian of the transformation with shape (N,) if compute_jacobian==True, else None.
    """

    # TODO bottom-up assessment of handling the differentiability of variables

    # Compute the bin width w
    N, k, b = q_tilde.shape
    Ny, ky = y.shape
    assert N == Ny and k == ky, "Shape mismatch"

    w = 1.0 / b

    # Compute normalized bin heights with softmax function on the bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)

    # Compute the integral over the left-bins in the forward transform.
    # 1. Compute all integrals: cumulative sum of bin height * bin weight.
    # We want that index i contains the cumsum *strictly to the left*,
    # so we shift by 1 leaving the first entry null,
    # which is achieved with a roll and assignment
    q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # Find which bin each y belongs to by finding the smallest bin such that
    # y - q_left_integral is positive

    edges = (y.unsqueeze(-1) - q_left_integrals).detach()
    # y and q_left_integrals are between 0 and 1,
    # so that their difference is at most 1.
    # By setting the negative values to 2., we know that the
    # smallest value left is the smallest positive
    edges[edges < 0] = 2.0
    edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)

    # Need special error handling because trying to index with mx
    # if it contains nans will lock the GPU. (device-side assert triggered)
    if (
        torch.any(torch.isnan(edges)).item()
        or torch.any(edges < 0)
        or torch.any(edges >= b)
    ):
        raise Exception("NaN detected in PWLinear bin indexing")

    # Gather the left integrals at each edge. See comment about gathering in q_left_integrals
    # for the unsqueeze
    q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Gather the slope at each edge.
    q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Build the output
    x = (y - q_left_integrals) / q + edges * w

    # Regularization: points must be strictly within the unit hypercube
    # Use the dtype information from pytorch
    eps = torch.finfo(x.dtype).eps
    x = x.clamp(min=eps, max=1.0 - eps)
    oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
    if outlier_passthru:
        x = x * (1 - oob_mask) + y * oob_mask
        q = q * (1 - oob_mask) + oob_mask

    # Prepare the jacobian
    logj = None
    if compute_jacobian:
        # logj = - torch.log(torch.prod(q, 1))
        logj = -torch.sum(torch.log(q.float()), 1)
    return x.detach(), logj


def unbounded_piecewise_quadratic_transform(
    x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
):
    assert upper > lower
    _range = upper - lower
    inside_interval_mask = (x >= lower) & (x < upper)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(x)
    log_j = torch.zeros_like(x)

    outputs[outside_interval_mask] = x[outside_interval_mask]
    log_j[outside_interval_mask] = 0

    output, _log_j = piecewise_quadratic_transform(
        (x[inside_interval_mask] - lower) / _range,
        w_tilde[inside_interval_mask, :],
        v_tilde[inside_interval_mask, :],
        inverse=inverse,
    )
    outputs[inside_interval_mask] = output * _range + lower
    if not inverse:
        # the before and after transformation cancel out, so the log_j would be just as it is.
        log_j[inside_interval_mask] = _log_j
    else:
        log_j = None
    return outputs, log_j


def weighted_softmax(v, w):
    # to avoid NaN...
    v = v - torch.max(v, dim=-1, keepdim=True)[0]
    v = torch.exp(v) + 1e-8  # to avoid NaN...
    v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
    return v / v_sum


def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
    """Element-wise piecewise-quadratic transformation
    Parameters
    ----------
    x : torch.Tensor
        *, The variable spans the D-dim unit hypercube ([0,1))
    w_tilde : torch.Tensor
        * x K defined in the paper
    v_tilde : torch.Tensor
        * x (K+1) defined in the paper
    inverse : bool
        forward or inverse
    Returns
    -------
    c : torch.Tensor
        *, transformed value
    log_j : torch.Tensor
        *, log determinant of the Jacobian matrix
    """
    w = torch.softmax(w_tilde, dim=-1)
    v = weighted_softmax(v_tilde, w)
    w_cumsum = torch.cumsum(w, dim=-1)
    # force sum = 1
    w_cumsum[..., -1] = 1.0
    w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
    cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
    # force sum = 1
    cdf[..., -1] = 1.0
    cdf_shift = F.pad(cdf, (1, 0), "constant", 0)

    if not inverse:
        # * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
        bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
    else:
        # * x D x 1, (cdf[idx-1] < x <= cdf[idx])
        bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))

    w_b = torch.gather(w, -1, bin_index).squeeze(-1)
    w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
    v_b = torch.gather(v, -1, bin_index).squeeze(-1)
    v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
    cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)

    if not inverse:
        alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
        c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1

        # just sum of log pdfs
        log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()

        # make sure it falls into [0,1)
        c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
        return c, log_j
    else:
        # quadratic equation for alpha
        # alpha should fall into (0, 1]. Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root
        # skip calculating the log_j in inverse since we don't need it
        a = (v_bp1 - v_b) * w_b / 2
        b = v_b * w_b
        c = cdf_bn1 - x
        alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
        inv = alpha * w_b + w_bn1

        # make sure it falls into [0,1)
        inv = inv.clamp(
            min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
        )
        return inv, None
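A quick round-trip check of the piecewise-linear pair above, assuming the file is importable as splines (run from the repo root); the shapes follow the docstrings, (N, k) points and (N, k, b) bin heights:

import torch
from splines import piecewise_linear_transform, piecewise_linear_inverse_transform

torch.manual_seed(0)
x = torch.rand(4, 3)            # (N, k) points inside the unit hypercube
q_tilde = torch.randn(4, 3, 8)  # (N, k, b) unnormalized bin heights
y, logj = piecewise_linear_transform(x, q_tilde)
x_back, inv_logj = piecewise_linear_inverse_transform(y, q_tilde)
print(torch.allclose(x, x_back, atol=1e-4))        # True: inverse recovers x
print(torch.allclose(logj, -inv_logj, atol=1e-4))  # log-jacobians cancel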
transformer.py
ADDED
@@ -0,0 +1,219 @@
# adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/transformer.py
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from common import get_mask_from_lengths, LinearNorm


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.matmul(
            torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0)
        )
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]


class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out).to(inp.dtype)

        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.scale = 1 / (d_head**0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        n_head, d_head = self.n_head, self.d_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)

        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = (
            attn_vec.permute(1, 2, 0, 3)
            .contiguous()
            .view(inp.size(0), inp.size(1), n_head * d_head)
        )

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        # residual connection + layer normalization
        output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output


class TransformerLayer(nn.Module):
    def __init__(
        self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs
    ):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout)

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim=1,
        n_layers=6,
        n_head=1,
        d_head=64,
        d_inner=1024,
        kernel_size=3,
        dropout=0.1,
        dropatt=0.1,
        dropemb=0.0,
    ):
        super(FFTransformer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_head = n_head
        self.d_head = d_head

        self.pos_emb = PositionalEmbedding(self.in_dim)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layers):
            self.layers.append(
                TransformerLayer(
                    n_head,
                    in_dim,
                    d_head,
                    d_inner,
                    kernel_size,
                    dropout,
                    dropatt=dropatt,
                )
            )

        self.dense = LinearNorm(in_dim, out_dim)

    def forward(self, dec_inp, in_lens):
        # B, C, T --> B, T, C
        inp = dec_inp.transpose(1, 2)
        mask = get_mask_from_lengths(in_lens)[..., None]

        pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask

        out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        out = self.dense(out).transpose(1, 2)
        return out
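A minimal shape check for FFTransformer, assuming common.py (imported above) is on the path; the hyperparameters here are arbitrary small values for illustration, not the ones from the model config:

import torch
from transformer import FFTransformer

model = FFTransformer(in_dim=80, out_dim=1, n_layers=2)
x = torch.randn(2, 80, 50)     # (B, C, T)
lens = torch.tensor([50, 35])  # per-sequence valid lengths
out = model(x, lens)
print(out.shape)               # torch.Size([2, 1, 50])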
tts_text_processing/LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
tts_text_processing/abbreviations.py
ADDED
@@ -0,0 +1,57 @@
import re

_no_period_re = re.compile(r"(No[.])(?=[ ]?[0-9])")
_percent_re = re.compile(r"([ ]?[%])")
_half_re = re.compile("([0-9]½)|(½)")


# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("ms", "miss"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


def _expand_no_period(m):
    word = m.group(0)
    if word[0] == "N":
        return "Number"
    return "number"


def _expand_percent(m):
    return " percent"


def _expand_half(m):
    word = m.group(1)
    if word is None:
        return "half"
    return word[0] + " and a half"


def normalize_abbreviations(text):
    text = re.sub(_no_period_re, _expand_no_period, text)
    text = re.sub(_percent_re, _expand_percent, text)
    text = re.sub(_half_re, _expand_half, text)
    return text
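Note that normalize_abbreviations only applies the No./percent/half rules; the _abbreviations table is defined here but not referenced by this function. A quick usage sketch:

from tts_text_processing.abbreviations import normalize_abbreviations

print(normalize_abbreviations("No. 7 costs 50% more"))
# -> 'Number 7 costs 50 percent more'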
tts_text_processing/acronyms.py
ADDED
@@ -0,0 +1,69 @@
import re

_letter_to_arpabet = {
    "A": "EY1",
    "B": "B IY1",
    "C": "S IY1",
    "D": "D IY1",
    "E": "IY1",
    "F": "EH1 F",
    "G": "JH IY1",
    "H": "EY1 CH",
    "I": "AY1",
    "J": "JH EY1",
    "K": "K EY1",
    "L": "EH1 L",
    "M": "EH1 M",
    "N": "EH1 N",
    "O": "OW1",
    "P": "P IY1",
    "Q": "K Y UW1",
    "R": "AA1 R",
    "S": "EH1 S",
    "T": "T IY1",
    "U": "Y UW1",
    "V": "V IY1",
    "X": "EH1 K S",
    "Y": "W AY1",
    "W": "D AH1 B AH0 L Y UW0",
    "Z": "Z IY1",
    "s": "Z",
}

# must ignore roman numerals
# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
_acronym_re = re.compile(r"([A-Z][A-Z]+)s?")


class AcronymNormalizer(object):
    def __init__(self, phoneme_dict):
        self.phoneme_dict = phoneme_dict

    def normalize_acronyms(self, text):
        def _expand_acronyms(m, add_spaces=True):
            acronym = m.group(0)
            # remove dots if they exist
            acronym = re.sub(r"\.", "", acronym)

            acronym = "".join(acronym.split())
            arpabet = self.phoneme_dict.lookup(acronym)

            if arpabet is None:
                acronym = list(acronym)
                arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
                # temporary fix
                if arpabet[-1] == "{Z}" and len(arpabet) > 1:
                    arpabet[-2] = arpabet[-2][:-1] + " " + arpabet[-1][1:]
                    del arpabet[-1]
                arpabet = " ".join(arpabet)
            elif len(arpabet) == 1:
                arpabet = "{" + arpabet[0] + "}"
            else:
                arpabet = acronym
            return arpabet

        text = re.sub(_acronym_re, _expand_acronyms, text)
        return text

    def __call__(self, text):
        return self.normalize_acronyms(text)
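Usage sketch with a stub phoneme dictionary (a hypothetical stand-in for the real dictionary object, which only needs a lookup method); unknown acronyms fall back to letter-by-letter ARPAbet:

from tts_text_processing.acronyms import AcronymNormalizer

class EmptyDict:
    def lookup(self, word):
        return None  # force the letter-by-letter fallback

normalizer = AcronymNormalizer(EmptyDict())
print(normalizer("play some ABBA"))
# -> 'play some {EY1} {B IY1} {B IY1} {EY1}'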
tts_text_processing/cleaners.py
ADDED
@@ -0,0 +1,123 @@
"""adapted from https://github.com/keithito/tacotron"""

"""
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
"""

import re
from string import punctuation
from functools import reduce
from unidecode import unidecode
from .numerical import normalize_numbers, normalize_currency
from .acronyms import AcronymNormalizer
from .datestime import normalize_datestime
from .letters_and_numbers import normalize_letters_and_numbers
from .abbreviations import normalize_abbreviations


# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")

# Regular expression separating words enclosed in curly braces for cleaning
_arpa_re = re.compile(r"{[^}]+}|\S+")


def expand_abbreviations(text):
    return normalize_abbreviations(text)


def expand_numbers(text):
    return normalize_numbers(text)


def expand_currency(text):
    return normalize_currency(text)


def expand_datestime(text):
    return normalize_datestime(text)


def expand_letters_and_numbers(text):
    return normalize_letters_and_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def separate_acronyms(text):
    text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
    text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
    return text


def convert_to_ascii(text):
    return unidecode(text)


def dehyphenize_compound_words(text):
    text = re.sub(r"(?<=[a-zA-Z0-9])-(?=[a-zA-Z])", " ", text)
    return text


def remove_space_before_punctuation(text):
    return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)


class Cleaner(object):
    def __init__(self, cleaner_names, phonemedict):
        self.cleaner_names = cleaner_names
        self.phonemedict = phonemedict
        self.acronym_normalizer = AcronymNormalizer(self.phonemedict)

    def __call__(self, text):
        for cleaner_name in self.cleaner_names:
            sequence_fns, word_fns = self.get_cleaner_fns(cleaner_name)
            for fn in sequence_fns:
                text = fn(text)

            text = [
                reduce(lambda x, y: y(x), word_fns, split) if split[0] != "{" else split
                for split in _arpa_re.findall(text)
            ]
            text = " ".join(text)
        text = remove_space_before_punctuation(text)
        return text

    def get_cleaner_fns(self, cleaner_name):
        if cleaner_name == "basic_cleaners":
            sequence_fns = [lowercase, collapse_whitespace]
            word_fns = []
        elif cleaner_name == "english_cleaners":
            sequence_fns = [collapse_whitespace, convert_to_ascii, lowercase]
            word_fns = [expand_numbers, expand_abbreviations]
        elif cleaner_name == "radtts_cleaners":
            sequence_fns = [
                collapse_whitespace,
                expand_currency,
                expand_datestime,
                expand_letters_and_numbers,
            ]
            word_fns = [expand_numbers, expand_abbreviations]
        elif cleaner_name == "ukrainian_cleaners":
            sequence_fns = [lowercase, collapse_whitespace]
            word_fns = []
        elif cleaner_name == "transliteration_cleaners":
            sequence_fns = [convert_to_ascii, lowercase, collapse_whitespace]
        else:
            raise Exception("{} cleaner not supported".format(cleaner_name))

        return sequence_fns, word_fns
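Usage sketch, run from the repo root so the relative imports resolve; ukrainian_cleaners uses no phoneme dictionary, so None is fine for phonemedict:

from tts_text_processing.cleaners import Cleaner

cleaner = Cleaner(["ukrainian_cleaners"], phonemedict=None)
print(cleaner("Привіт,   СВІТЕ !"))
# -> 'привіт, світе!'  (lowercased, whitespace collapsed, space-before-! removed)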
tts_text_processing/cmudict.py
ADDED
@@ -0,0 +1,140 @@
"""adapted from https://github.com/keithito/tacotron"""

import re


valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0",
    "AH1", "AH2", "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2",
    "AY", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH", "EH0", "EH1",
    "EH2", "ER", "ER0", "ER1", "ER2", "EY", "EY0", "EY1", "EY2", "F",
    "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", "IY2",
    "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY",
    "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0",
    "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH",
]

_valid_symbol_set = set(valid_symbols)


class CMUDict:
    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""

    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Returns list of ARPAbet pronunciations of the given word."""
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
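CMUDict also accepts an open file object, which makes it easy to exercise with an in-memory snippet (cmudict format is assumed here: word, two spaces, ARPAbet pronunciation):

import io
from tts_text_processing.cmudict import CMUDict

cmu = CMUDict(io.StringIO("HELLO  HH AH0 L OW1\nWORLD  W ER1 L D\n"))
print(len(cmu))             # 2
print(cmu.lookup("hello"))  # ['HH AH0 L OW1']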
tts_text_processing/datestime.py
ADDED
@@ -0,0 +1,24 @@
"""adapted from https://github.com/keithito/tacotron"""

import re

_ampm_re = re.compile(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)")


def _expand_ampm(m):
    matches = list(m.groups(0))
    txt = matches[0]
    txt = txt if int(matches[1]) == 0 else txt + " " + matches[1]

    if matches[2][0].lower() == "a":
        txt += " a.m."
    elif matches[2][0].lower() == "p":
        txt += " p.m."

    return txt


def normalize_datestime(text):
    text = re.sub(_ampm_re, _expand_ampm, text)
    # text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
    return text
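Usage sketch for the am/pm expansion:

from tts_text_processing.datestime import normalize_datestime

print(normalize_datestime("The train leaves at 10:30am and 5 PM"))
# -> 'The train leaves at 10 30 a.m. and 5 p.m.'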
tts_text_processing/grapheme_dictionary.py
ADDED
@@ -0,0 +1,37 @@
"""adapted from https://github.com/keithito/tacotron"""

import re

_alt_re = re.compile(r"\([0-9]+\)")


class Grapheme2PhonemeDictionary:
    """Thin wrapper around g2p data."""

    def __init__(self, file_or_path, keep_ambiguous=True, encoding="latin-1"):
        with open(file_or_path, encoding=encoding) as f:
            entries = _parse_g2p(f)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Returns list of pronunciations of the given word."""
        return self._entries.get(word.upper())


def _parse_g2p(file):
    g2p = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = parts[1].strip()
            if word in g2p:
                g2p[word].append(pronunciation)
            else:
                g2p[word] = [pronunciation]
    return g2p
tts_text_processing/heteronyms
ADDED
@@ -0,0 +1,413 @@
abject
abrogate
absent
abstract
abuse
ache
acre
acuminate
addict
address
adduct
adele
advocate
affect
affiliate
agape
aged
agglomerate
aggregate
agonic
agora
allied
ally
alternate
alum
am
analyses
andrea
animate
apply
appropriate
approximate
ares
arithmetic
arsenic
articulate
associate
attribute
august
axes
ay
aye
bases
bass
bathed
bested
bifurcate
blessed
blotto
bow
bowed
bowman
brassy
buffet
bustier
carbonate
celtic
choral
chumash
close
closer
coax
coincidence
color coordinate
colour coordinate
comber
combine
combs
committee
commune
compact
complex
compound
compress
concert
conduct
confine
confines
conflict
conglomerate
conscript
conserve
consist
console
consort
construct
consult
consummate
content
contest
contract
contracts
contrast
converse
convert
convict
coop
coordinate
covey
crooked
curate
cussed
decollate
decrease
defect
defense
delegate
deliberate
denier
desert
detail
deviate
diagnoses
diffuse
digest
discard
discharge
discount
do
document
does
dogged
domesticate
dominican
dove
dr
drawer
duplicate
egress
ejaculate
eject
elaborate
ellipses
email
emu
entrace
entrance
escort
estimate
eta
etna
evening
excise
excuse
exploit
export
extract
fine
flower
forbear
four-legged
frequent
furrier
gallant
gel
geminate
gillie
glower
gotham
graduate
haggis
heavy
hinder
house
housewife
impact
imped
implant
implement
import
impress
incense
incline
increase
infix
insert
instar
insult
integral
intercept
interchange
interflow
interleaf
intermediate
intern
interspace
intimate
intrigue
invalid
invert
invite
irony
jagged
jesses
julies
kite
laminate
laos
lather
lead
learned
leasing
lech
legitimate
lied
lima
lipread
live
lower
lunged
maas
magdalen
manes
mare
marked
merchandise
merlion
minute
misconduct
misled
misprint
mobile
moderate
mong
moped
moth
mouth
mow
mpg
multiply
mush
nana
nice
nice
number
numerate
nun
object
opiate
ornament
outbox
outcry
outpour
outreach
outride
outright
outside
outwork
overall
overbid
overcall
overcast
overfall
overflow
overhaul
overhead
overlap
overlay
overuse
overweight
overwork
pace
palled
palling
para
pasty
pate
pauline
pedal
peer
perfect
periodic
permit
pervert
pinta
placer
platy
polish
polish
poll
pontificate
postulate
pram
prayer
precipitate
predate
predicate
prefix
preposition
present
pretest
primer
proceeds
produce
progress
project
proportionate
prospect
protest
pussy
putter
putting
quite
ragged
raven
re
read
reading
reading
real
rebel
recall
recap
recitative
recollect
record
recreate
recreation
redress
refill
refund
refuse
reject
relay
remake
repaint
reprint
reread
rerun
resent
reside
resign
respray
resume
retard
retest
retread
rewrite
root
routed
routing
row
rugged
rummy
sais
sake
sambuca
saucier
second
secrete
secreted
secreting
segment
separate
sewer
shirk
shower
sin
skied
slaver
slough
sow
spoof
squid
stingy
subject
subordinate
subvert
supply
supposed
survey
suspect
syringes
tabulate
tales
tarrier
tarry
taxes
taxis
tear
theron
thou
three-legged
tier
tinged
torment
transfer
transform
transplant
transport
transpose
tush
two-legged
unionised
unionized
update
uplift
upset
use
used
vale
violist
viva
ware
whinged
whoop
wicked
wind
windy
wino
won
worsted
wound
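A sketch of loading this one-word-per-line list into a set (assumed usage; the actual consumer is the text-processing pipeline elsewhere in this repo):

with open("tts_text_processing/heteronyms", encoding="utf-8") as f:
    heteronyms = {line.strip() for line in f if line.strip()}
print("wind" in heteronyms)  # True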
tts_text_processing/letters_and_numbers.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+"""adapted from https://github.com/keithito/tacotron"""
+
+import re
+
+_letters_and_numbers_re = re.compile(
+    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE
+)
+
+_hardware_re = re.compile(
+    r"([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm|cm|km)",
+    re.IGNORECASE,
+)
+_hardware_key = {
+    "tb": "terabyte",
+    "gb": "gigabyte",
+    "mb": "megabyte",
+    "kb": "kilobyte",
+    "ghz": "gigahertz",
+    "mhz": "megahertz",
+    "khz": "kilohertz",
+    "hz": "hertz",
+    "mm": "millimeter",
+    "cm": "centimeter",
+    "km": "kilometer",
+}
+
+_dimension_re = re.compile(
+    r"\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
+    r"|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
+)
+_dimension_key = {"m": "meter", "in": "inch", "inch": "inch"}
+
+
+def _expand_letters_and_numbers(m):
+    text = re.split(r"(\d+)", m.group(0))
+
+    # re.split leaves an empty string at whichever end of the match
+    # starts or ends with a digit; drop it
+    if text[-1] == "":
+        text = text[:-1]
+    if text[0] == "":
+        text = text[1:]
+
+    # re-attach ordinal/possessive suffixes (1920s, AK47's, 20th, 1st, ...)
+    # to their digits so they stay a single token
+    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
+        text[-2] = text[-2] + text[-1]
+        text = text[:-1]
+
+    # combine digits two by two (e.g. "1945" -> "19", "45")
+    new_text = []
+    for i in range(len(text)):
+        string = text[i]
+        if string.isdigit() and len(string) < 5:
+            # heuristics
+            if len(string) > 2 and string[-2] == "0":
+                if string[-1] == "0":
+                    string = [string]
+                else:
+                    string = [string[:-3], string[-2], string[-1]]
+            elif len(string) % 2 == 0:
+                string = [string[i : i + 2] for i in range(0, len(string), 2)]
+            elif len(string) > 2:
+                string = [string[0]] + [
+                    string[i : i + 2] for i in range(1, len(string), 2)
+                ]
+            new_text.extend(string)
+        else:
+            new_text.append(string)
+
+    text = new_text
+    text = " ".join(text)
+    return text
+
+
+def _expand_hardware(m):
+    quantity, measure = m.groups(0)
+    measure = _hardware_key[measure.lower()]
+    # pluralize unless the unit ends in "z" (hertz) or the quantity is 1
+    if measure[-1] != "z" and float(quantity.replace(",", "")) > 1:
+        return "{} {}s".format(quantity, measure)
+    return "{} {}".format(quantity, measure)
+
+
+def _expand_dimension(m):
+    text = "".join([x for x in m.groups(0) if x != 0])
+    text = text.replace(" x ", " by ")
+    text = text.replace("x", " by ")
+    if text.endswith(tuple(_dimension_key.keys())):
+        if text[-2].isdigit():
+            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
+        elif text[-3].isdigit():
+            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
+    return text
+
+
+def normalize_letters_and_numbers(text):
+    text = re.sub(_hardware_re, _expand_hardware, text)
+    text = re.sub(_dimension_re, _expand_dimension, text)
+    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
+    return text
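
The module above is pure string rewriting, so its behavior is easy to pin down. A doctest-style sketch, not part of the commit; the sample strings are illustrative and assume the package is importable from the repo root:

    >>> from tts_text_processing.letters_and_numbers import normalize_letters_and_numbers
    >>> normalize_letters_and_numbers("an AK47 with 32GB")
    'an AK 47 with 32 gigabytes'
    >>> normalize_letters_and_numbers("a 3x4in print")
    'a 3 by 4 inch print'

Digits are separated from letters but not spelled out; turning "47" into words is left to normalize_numbers in tts_text_processing/numerical.py below.
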
tts_text_processing/numerical.py
ADDED
@@ -0,0 +1,175 @@
+"""adapted from https://github.com/keithito/tacotron"""
+
+import inflect
+import re
+
+_magnitudes = ["trillion", "billion", "million", "thousand", "hundred", "m", "b", "t"]
+_magnitudes_key = {"m": "million", "b": "billion", "t": "trillion"}
+_measurements = "(f|c|k|d|m)"
+_measurements_key = {"f": "fahrenheit", "c": "celsius", "k": "thousand", "m": "meters"}
+_currency_key = {"$": "dollar", "£": "pound", "€": "euro", "₩": "won"}
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_currency_re = re.compile(
+    r"([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?".format(
+        "|".join(_magnitudes)
+    ),
+    re.IGNORECASE,
+)
+_measurement_re = re.compile(
+    r"([0-9\.\,]*[0-9]+(\s)?{}\b)".format(_measurements), re.IGNORECASE
+)
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
+_roman_re = re.compile(
+    r"\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b"
+)  # avoid matching a bare "I"
+_multiply_re = re.compile(r"(\b[0-9]+)(x)([0-9]+)")
+_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
+
+
+def _remove_commas(m):
+    return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m):
+    return m.group(1).replace(".", " point ")
+
+
+def _expand_currency(m):
+    currency = _currency_key[m.group(1)]
+    quantity = m.group(2)
+    magnitude = m.group(3)
+
+    # remove commas from quantity to be able to convert to numerical
+    quantity = quantity.replace(",", "")
+
+    # check for million, billion, etc...
+    if magnitude is not None and magnitude.lower() in _magnitudes:
+        if len(magnitude) == 1:
+            magnitude = _magnitudes_key[magnitude.lower()]
+        return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + "s")
+
+    parts = quantity.split(".")
+    if len(parts) > 2:
+        return quantity + " " + currency + "s"  # Unexpected format
+
+    dollars = int(parts[0]) if parts[0] else 0
+
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = currency if dollars == 1 else currency + "s"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "{} {}, {} {}".format(
+            _expand_hundreds(dollars),
+            dollar_unit,
+            _inflect.number_to_words(cents),
+            cent_unit,
+        )
+    elif dollars:
+        dollar_unit = currency if dollars == 1 else currency + "s"
+        return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
+    elif cents:
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
+    else:
+        return "zero" + " " + currency + "s"
+
+
+def _expand_hundreds(text):
+    number = float(text)
+    # numbers like 1200 are read as "twelve hundred"
+    if 1000 < number < 10000 and (number % 100 == 0) and (number % 1000 != 0):
+        return _inflect.number_to_words(int(number / 100)) + " hundred"
+    else:
+        return _inflect.number_to_words(text)
+
+
+def _expand_ordinal(m):
+    return _inflect.number_to_words(m.group(0))
+
+
+def _expand_measurement(m):
+    _, number, measurement = re.split(r"(\d+(?:\.\d+)?)", m.group(0))
+    number = _inflect.number_to_words(number)
+    measurement = "".join(measurement.split())
+    measurement = _measurements_key[measurement.lower()]
+    return "{} {}".format(number, measurement)
+
+
+def _expand_range(m):
+    return " to "
+
+
+def _expand_multiply(m):
+    left = m.group(1)
+    right = m.group(3)
+    return "{} by {}".format(left, right)
+
+
+def _expand_roman(m):
+    # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
+    roman_numerals = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
+    result = 0
+    num = m.group(0)
+    for i, c in enumerate(num):
+        if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
+            result += roman_numerals[c]
+        else:
+            result -= roman_numerals[c]
+    return str(result)
+
+
+def _expand_number(m):
+    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
+    number = int(number)
+    if (
+        number > 1000
+        and number < 10000
+        and (number % 100 == 0)
+        and (number % 1000 != 0)
+    ):
+        text = _inflect.number_to_words(number // 100) + " hundred"
+    elif number > 1000 and number < 3000:
+        if number == 2000:
+            text = "two thousand"
+        elif number > 2000 and number < 2010:
+            text = "two thousand " + _inflect.number_to_words(number % 100)
+        elif number % 100 == 0:
+            text = _inflect.number_to_words(number // 100) + " hundred"
+        else:
+            number = _inflect.number_to_words(
+                number, andword="", zero="oh", group=2
+            ).replace(", ", " ")
+            number = re.sub(r"-", " ", number)
+            text = number
+    else:
+        number = _inflect.number_to_words(number, andword="and")
+        number = re.sub(r"-", " ", number)
+        number = re.sub(r",", "", number)
+        text = number
+
+    if suffix in ("'s", "s"):
+        if text[-1] == "y":
+            text = text[:-1] + "ies"
+        else:
+            text = text + suffix
+
+    return text
+
+
+def normalize_currency(text):
+    return re.sub(_currency_re, _expand_currency, text)
+
+
+def normalize_numbers(text):
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_currency_re, _expand_currency, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    # text = re.sub(_range_re, _expand_range, text)
+    # text = re.sub(_measurement_re, _expand_measurement, text)
+    text = re.sub(_roman_re, _expand_roman, text)
+    text = re.sub(_multiply_re, _expand_multiply, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
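
A sketch of the combined pipeline in normalize_numbers, not part of the commit; the sample input is illustrative:

    >>> from tts_text_processing.numerical import normalize_numbers
    >>> normalize_numbers("Chapter XII was published in 1999 for $3.50")
    'Chapter twelve was published in nineteen ninety nine for three dollars, fifty cents'

Roman numerals are rewritten to digits first (XII -> 12), currency symbols are expanded together with their unit, and four-digit numbers between 1000 and 3000 are mostly read as digit pairs ("nineteen ninety nine").
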
tts_text_processing/symbols.py
ADDED
@@ -0,0 +1,144 @@
+"""adapted from https://github.com/keithito/tacotron"""
+
+"""
+Defines the set of symbols used in text input to the model.
+
+The default is a set of ASCII characters that works well for English or text
+that has been run through Unidecode. For other data, you can modify
+_characters."""
+
+
+arpabet = [
+    "AA",
+    "AA0",
+    "AA1",
+    "AA2",
+    "AE",
+    "AE0",
+    "AE1",
+    "AE2",
+    "AH",
+    "AH0",
+    "AH1",
+    "AH2",
+    "AO",
+    "AO0",
+    "AO1",
+    "AO2",
+    "AW",
+    "AW0",
+    "AW1",
+    "AW2",
+    "AY",
+    "AY0",
+    "AY1",
+    "AY2",
+    "B",
+    "CH",
+    "D",
+    "DH",
+    "EH",
+    "EH0",
+    "EH1",
+    "EH2",
+    "ER",
+    "ER0",
+    "ER1",
+    "ER2",
+    "EY",
+    "EY0",
+    "EY1",
+    "EY2",
+    "F",
+    "G",
+    "HH",
+    "IH",
+    "IH0",
+    "IH1",
+    "IH2",
+    "IY",
+    "IY0",
+    "IY1",
+    "IY2",
+    "JH",
+    "K",
+    "L",
+    "M",
+    "N",
+    "NG",
+    "OW",
+    "OW0",
+    "OW1",
+    "OW2",
+    "OY",
+    "OY0",
+    "OY1",
+    "OY2",
+    "P",
+    "R",
+    "S",
+    "SH",
+    "T",
+    "TH",
+    "UH",
+    "UH0",
+    "UH1",
+    "UH2",
+    "UW",
+    "UW0",
+    "UW1",
+    "UW2",
+    "V",
+    "W",
+    "Y",
+    "Z",
+    "ZH",
+]
+
+
+def get_symbols(symbol_set):
+    if symbol_set == "english_basic":
+        _pad = "_"
+        _punctuation = "!'\"(),.:;? "
+        _special = "-"
+        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+        _arpabet = ["@" + s for s in arpabet]
+        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
+    elif symbol_set == "english_basic_lowercase":
+        _pad = "_"
+        _punctuation = "!'\"(),.:;? "
+        _special = "-"
+        _letters = "abcdefghijklmnopqrstuvwxyz"
+        _arpabet = ["@" + s for s in arpabet]
+        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
+    elif symbol_set == "english_expanded":
+        _punctuation = "!'\",.:;? "
+        _math = "#%&*+-/[]()"
+        _special = "_@©°½—₩€$"
+        _accented = "áçéêëñöøćž"
+        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+        _arpabet = ["@" + s for s in arpabet]
+        symbols = (
+            list(_punctuation + _math + _special + _accented + _letters) + _arpabet
+        )
+    elif symbol_set == "ukrainian":
+        _punctuation = "'.,?! "
+        _special = "-+"
+        _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
+        symbols = list(_punctuation + _special + _letters)
+    elif symbol_set == "radtts":
+        _punctuation = "!'\",.:;? "
+        _math = "#%&*+-/[]()"
+        _special = "_@©°½—₩€$"
+        _accented = "áçéêëñöøćž"
+        _numbers = "0123456789"
+        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+        _arpabet = ["@" + s for s in arpabet]
+        symbols = (
+            list(_punctuation + _math + _special + _accented + _numbers + _letters)
+            + _arpabet
+        )
+    else:
+        raise Exception("{} symbol set does not exist".format(symbol_set))
+
+    return symbols
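
The ordering of get_symbols matters: token IDs are just list indices, so the list must match what the checkpoint was trained with. For the "ukrainian" set that is 6 punctuation marks, the "-" and "+" specials, and the 33 letters of the Ukrainian alphabet, 41 symbols in total. A quick check, not part of the commit:

    >>> from tts_text_processing.symbols import get_symbols
    >>> symbols = get_symbols("ukrainian")
    >>> len(symbols)
    41
    >>> symbols[:8]
    ["'", '.', ',', '?', '!', ' ', '-', '+']
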
tts_text_processing/text_processing.py
ADDED
@@ -0,0 +1,201 @@
+"""adapted from https://github.com/keithito/tacotron"""
+
+import re
+import numpy as np
+from .cleaners import Cleaner
+from .symbols import get_symbols
+from .grapheme_dictionary import Grapheme2PhonemeDictionary
+
+
+#########
+# REGEX #
+#########
+
+# Regular expression matching text enclosed in curly braces for encoding
+_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
+
+# Regular expression matching words and not-words
+_words_re = re.compile(
+    r"([a-zA-ZÀ-ž]+['][a-zA-ZÀ-ž]+|[a-zA-ZÀ-ž]+)|([{][^}]+[}]|[^a-zA-ZÀ-ž{}]+)"
+)
+
+
+def lines_to_list(filename):
+    with open(filename, encoding="utf-8") as f:
+        lines = f.readlines()
+    lines = [l.rstrip() for l in lines]
+    return lines
+
+
+class TextProcessing(object):
+    def __init__(
+        self,
+        symbol_set,
+        cleaner_name,
+        heteronyms_path,
+        phoneme_dict_path,
+        p_phoneme,
+        handle_phoneme,
+        handle_phoneme_ambiguous,
+        prepend_space_to_text=False,
+        append_space_to_text=False,
+        add_bos_eos_to_text=False,
+        encoding="latin-1",  # currently unused
+    ):
+        if heteronyms_path is not None and heteronyms_path != "":
+            self.heteronyms = set(lines_to_list(heteronyms_path))
+        else:
+            self.heteronyms = []
+        # phoneme dict is stubbed out with a plain dict; its .lookup() is only
+        # reached through get_phoneme(), i.e. when p_phoneme > 0
+        self.phonemedict = {}
+
+        self.p_phoneme = p_phoneme
+        self.handle_phoneme = handle_phoneme
+        self.handle_phoneme_ambiguous = handle_phoneme_ambiguous
+
+        self.symbols = get_symbols(symbol_set)
+        self.cleaner_names = cleaner_name
+        self.cleaner = Cleaner(cleaner_name, self.phonemedict)
+
+        self.prepend_space_to_text = prepend_space_to_text
+        self.append_space_to_text = append_space_to_text
+        self.add_bos_eos_to_text = add_bos_eos_to_text
+
+        if add_bos_eos_to_text:
+            self.symbols.append("<bos>")
+            self.symbols.append("<eos>")
+
+        # Mappings from symbol to numeric ID and vice versa:
+        self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
+        self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
+
+    def text_to_sequence(self, text):
+        sequence = []
+
+        # Check for curly braces and treat their contents as phonemes:
+        while len(text):
+            m = _curly_re.match(text)
+            if not m:
+                sequence += self.symbols_to_sequence(text)
+                break
+            sequence += self.symbols_to_sequence(m.group(1))
+            sequence += self.phoneme_to_sequence(m.group(2))
+            text = m.group(3)
+
+        return sequence
+
+    def sequence_to_text(self, sequence):
+        result = ""
+        for symbol_id in sequence:
+            if symbol_id in self.id_to_symbol:
+                s = self.id_to_symbol[symbol_id]
+                # Enclose phonemes back in curly braces:
+                if len(s) > 1 and s[0] == "@":
+                    s = "{%s}" % s[1:]
+                result += s
+        return result.replace("}{", " ")
+
+    def clean_text(self, text):
+        text = self.cleaner(text)
+        return text
+
+    def symbols_to_sequence(self, symbols):
+        return [self.symbol_to_id[s] for s in symbols if s in self.symbol_to_id]
+
+    def phoneme_to_sequence(self, text):
+        return self.symbols_to_sequence(["@" + s for s in text.split()])
+
+    def get_phoneme(self, word):
+        phoneme_suffix = ""
+
+        # heteronyms are left as graphemes; their pronunciation is ambiguous
+        if word.lower() in self.heteronyms:
+            return word
+
+        if len(word) > 2 and word.endswith("'s"):
+            phoneme = self.phonemedict.lookup(word)
+            if phoneme is None:
+                phoneme = self.phonemedict.lookup(word[:-2])
+                phoneme_suffix = "" if phoneme is None else " Z"
+
+        elif len(word) > 1 and word.endswith("s"):
+            phoneme = self.phonemedict.lookup(word)
+            if phoneme is None:
+                phoneme = self.phonemedict.lookup(word[:-1])
+                phoneme_suffix = "" if phoneme is None else " Z"
+        else:
+            phoneme = self.phonemedict.lookup(word)
+
+        if phoneme is None:
+            return word
+
+        if len(phoneme) > 1:
+            if self.handle_phoneme_ambiguous == "first":
+                phoneme = phoneme[0]
+            elif self.handle_phoneme_ambiguous == "random":
+                phoneme = np.random.choice(phoneme)
+            elif self.handle_phoneme_ambiguous == "ignore":
+                return word
+        else:
+            phoneme = phoneme[0]
+
+        phoneme = "{" + phoneme + phoneme_suffix + "}"
+
+        return phoneme
+
+    def encode_text(self, text, return_all=False):
+        text_clean = self.clean_text(text)
+        text = text_clean
+
+        text_phoneme = ""
+        if self.p_phoneme > 0:
+            text_phoneme = self.convert_to_phoneme(text)
+            text = text_phoneme
+
+        text_encoded = self.text_to_sequence(text)
+
+        if self.prepend_space_to_text:
+            text_encoded.insert(0, self.symbol_to_id[" "])
+
+        if self.append_space_to_text:
+            text_encoded.append(self.symbol_to_id[" "])
+
+        if self.add_bos_eos_to_text:
+            text_encoded.insert(0, self.symbol_to_id["<bos>"])
+            text_encoded.append(self.symbol_to_id["<eos>"])
+
+        if return_all:
+            return text_encoded, text_clean, text_phoneme
+
+        return text_encoded
+
+    def convert_to_phoneme(self, text):
+        if self.handle_phoneme == "sentence":
+            if np.random.uniform() < self.p_phoneme:
+                words = _words_re.findall(text)
+                text_phoneme = [
+                    self.get_phoneme(word[0])
+                    if (word[0] != "")
+                    else re.sub(r"\s(\d)", r"\1", word[1].upper())
+                    for word in words
+                ]
+                text_phoneme = "".join(text_phoneme)
+                text = text_phoneme
+        elif self.handle_phoneme == "word":
+            words = _words_re.findall(text)
+            text_phoneme = [
+                re.sub(r"\s(\d)", r"\1", word[1].upper())
+                if word[0] == ""
+                else (
+                    self.get_phoneme(word[0])
+                    if np.random.uniform() < self.p_phoneme
+                    else word[0]
+                )
+                for word in words
+            ]
+            text_phoneme = "".join(text_phoneme)
+            text = text_phoneme
+        elif self.handle_phoneme != "":
+            raise Exception(
+                "{} handle_phoneme is not supported".format(self.handle_phoneme)
+            )
+        return text
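
A minimal end-to-end sketch of TextProcessing, not part of the commit. The cleaner name and flag values below are assumptions for illustration; the values the Space actually uses live in configs/radtts-pp-dap-model.json and are presumably wired up in app.py:

    from tts_text_processing.text_processing import TextProcessing

    tp = TextProcessing(
        symbol_set="ukrainian",
        cleaner_name="ukrainian_cleaners",  # assumed name; depending on Cleaner it may need to be a list
        heteronyms_path="",                 # no heteronym list
        phoneme_dict_path="",               # unused: the phoneme dict above is a stub
        p_phoneme=0.0,                      # graphemes only, convert_to_phoneme is skipped
        handle_phoneme="word",
        handle_phoneme_ambiguous="ignore",
    )
    ids = tp.encode_text("привіт, світе!")  # "hello, world!"; list of IDs into get_symbols("ukrainian")
    text_back = tp.sequence_to_text(ids)    # maps the IDs back to the cleaned string

With p_phoneme at zero, encode_text reduces to clean_text followed by text_to_sequence, and neither the heteronym set nor the phoneme dictionary is ever consulted.
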