Fixes to the codebase
- app.py +23 -49
- attribute_prediction_model.py +2 -0
- audio_processing.py +38 -37
- autoregressive_flow.py +2 -6
- common.py +32 -41
- data.py +10 -153
- export_weights.py +14 -0
- loss.py +0 -228
- partialconv1d.py +1 -2
- radam.py +0 -114
- radtts.py +18 -24
- requirements.txt +0 -3
- torch_env.py +19 -0
- tts_text_processing/abbreviations.py +0 -57
- tts_text_processing/acronyms.py +0 -69
- tts_text_processing/cleaners.py +4 -75
- tts_text_processing/cmudict.py +0 -140
- tts_text_processing/datestime.py +0 -24
- tts_text_processing/grapheme_dictionary.py +0 -37
- tts_text_processing/heteronyms +0 -413
- tts_text_processing/letters_and_numbers.py +0 -96
- tts_text_processing/numerical.py +0 -175
- tts_text_processing/symbols.py +0 -144
- tts_text_processing/text_processing.py +11 -13
app.py
CHANGED
@@ -6,38 +6,32 @@ import time
 from importlib.metadata import version
 from enum import Enum
 
-use_zerogpu = False
-
-import spaces  # it's for ZeroGPU
-
+import torch
+import torchaudio
+
+from huggingface_hub import hf_hub_download
+
+# RAD-TTS code
+from radtts import RADTTS
+from data import Data
+from common import update_params
+from torch_env import device
 
 import gradio as gr
 
-import torch
-import torchaudio
 
 # Vocos
 from vocos import Vocos
 
-from radtts import RADTTS
-from data import Data
-from common import update_params
-
-print("
-
-device = "cpu"
+use_zerogpu = False
+
+try:
+    import spaces  # it's for ZeroGPU
+
+    use_zerogpu = True
+    print("ZeroGPU is available, changing inference call.")
+except ImportError:
+    print("ZeroGPU is not available, skipping...")
 
 
 def download_file_from_repo(
@@ -65,15 +59,13 @@ def download_file_from_repo(
 
 download_file_from_repo(
     "Yehor/radtts-uk",
-    "radtts-pp-dap-model/model_dap_84000.pt",
+    "radtts-pp-dap-model/model_dap_84000_state.pt",
     "./models/",
 )
 
 # Init the model
-seed = 1234
-
 config = "configs/radtts-pp-dap-model.json"
-radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
+radtts_path = "models/radtts-pp-dap-model/model_dap_84000_state.pt"
 
 params = []
 
@@ -87,19 +79,11 @@ update_params(config, params)
 data_config = config["data_config"]
 model_config = config["model_config"]
 
-# Seed
-if use_cuda:
-    torch.cuda.manual_seed(seed)
-else:
-    torch.manual_seed(seed)
-
 # Load vocoder
 vocos = Vocos.from_pretrained("patriotyk/vocos-mel-hifigan-compat-44100khz").to(device)
 
 # Load RAD-TTS
-radtts = RADTTS(**model_config)
-if use_cuda:
-    radtts = radtts.cuda()
+radtts = RADTTS(**model_config).to(device)
 
 radtts.enable_inverse_cache()  # cache inverse matrix for 1x1 invertible convs
 
@@ -151,6 +135,7 @@ tech_env = f"""
 #### Environment
 
 - Python: {sys.version}
+- Torch device: {device}
 """.strip()
 
 tech_libraries = f"""
@@ -161,8 +146,6 @@ tech_libraries = f"""
 - scipy: {version("scipy")}
 - numba: {version("numba")}
 - librosa: {version("librosa")}
-- unidecode: {version("unidecode")}
-- inflect: {version("inflect")}
 """.strip()
 
 
@@ -218,25 +201,16 @@ def inference(text, voice):
     energy_mean = 0
     energy_std = 0
 
-    tensor_text = trainset.get_text(text)
+    tensor_text = trainset.get_text(text).to(device)
 
-    speaker_id = trainset.get_speaker_id(speaker)
+    speaker_id = trainset.get_speaker_id(speaker).to(device)
     speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
 
     if speaker_text is not None:
-        speaker_id_text = trainset.get_speaker_id(speaker_text)
+        speaker_id_text = trainset.get_speaker_id(speaker_text).to(device)
 
     if speaker_attributes is not None:
-        speaker_id_attributes = trainset.get_speaker_id(speaker_attributes)
-
-    if use_cuda:
-        tensor_text = tensor_text.cuda()
-        speaker_id = speaker_id.cuda()
-
-        if speaker_id_text is not None:
-            speaker_id_text = speaker_id_text.cuda()
-        if speaker_id_attributes is not None:
-            speaker_id_attributes = speaker_id_attributes.cuda()
+        speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).to(device)
 
     inference_start = time.time()
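Note that the app now detects ZeroGPU by attempting the `spaces` import rather than assuming it. How the `use_zerogpu` flag is consumed afterwards is not shown in this diff; the sketch below is an assumption based on the commit message ("changing inference call") and the public `spaces.GPU` decorator, not code from this commit:

    # Hypothetical follow-up to the try/except above: wrap the Gradio handler
    # with the ZeroGPU decorator only when the `spaces` package is present.
    if use_zerogpu:
        inference = spaces.GPU(inference)  # spaces.GPU requests a GPU slice per call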
attribute_prediction_model.py
CHANGED
@@ -18,8 +18,10 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
+
 import torch
 from torch import nn
+
 from common import ConvNorm, Invertible1x1Conv
 from common import AffineTransformationLayer, SplineTransformationLayer
 from common import ConvLSTMLinear
audio_processing.py
CHANGED
@@ -18,12 +18,50 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
+
+"""
+BSD 3-Clause License
+
+Copyright (c) 2017, Prem Seetharaman
+All rights reserved.
+
+* Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
 import torch
 import numpy as np
+
 from scipy.signal import get_window
 from librosa.filters import mel as librosa_mel_fn
 import librosa.util as librosa_util
 
+import torch.nn.functional as F
+from torch.autograd import Variable
+from librosa.util import pad_center, tiny
+
 
 def window_sumsquare(
     window,
@@ -174,43 +212,6 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output
 
 
-"""
-BSD 3-Clause License
-
-Copyright (c) 2017, Prem Seetharaman
-All rights reserved.
-
-* Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-* Redistributions of source code must retain the above copyright notice,
-this list of conditions and the following disclaimer.
-
-* Redistributions in binary form must reproduce the above copyright notice, this
-list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-
-* Neither the name of the copyright holder nor the names of its
-contributors may be used to endorse or promote products derived from this
-software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
-ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
-ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""
-import torch.nn.functional as F
-from torch.autograd import Variable
-from scipy.signal import get_window
-from librosa.util import pad_center, tiny
-
-
 class STFT(torch.nn.Module):
     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
 
autoregressive_flow.py
CHANGED
@@ -45,8 +45,7 @@ import torch
 from torch import nn
 
 from common import DenseLayer, SplineTransformationLayerAR
-
-use_cuda = torch.cuda.is_available()
+from torch_env import device
 
 
 class AR_Back_Step(torch.nn.Module):
@@ -229,10 +228,7 @@ class AR_Step(torch.nn.Module):
                 (1, residual.size(1), residual.size(2)), dtype=residual.dtype
             )
 
-            if use_cuda:
-                dummy = torch.tensor(data, device="cuda")
-            else:
-                dummy = torch.tensor(data)
+            dummy = torch.tensor(data, device=device)
 
             self.attr_lstm.flatten_parameters()
 
common.py
CHANGED
@@ -62,34 +62,7 @@ from splines import (
 )
 from partialconv1d import PartialConv1d as pconv1d
 from typing import Tuple
-
-use_cuda = torch.cuda.is_available()
-
-if use_cuda:
-    device = "cuda"
-else:
-    device = "cpu"
-
-
-def update_params(config, params):
-    for param in params:
-        print(param)
-        k, v = param.split("=")
-        try:
-            v = ast.literal_eval(v)
-        except:
-            pass
-
-        k_split = k.split(".")
-        if len(k_split) > 1:
-            parent_k = k_split[0]
-            cur_param = [".".join(k_split[1:]) + "=" + str(v)]
-            update_params(config[parent_k], cur_param)
-        elif k in config and len(k_split) == 1:
-            print(f"overriding {k} with {v}")
-            config[k] = v
-        else:
-            print("{}, {} params not updated".format(k, v))
+from torch_env import device
 
 
 def get_mask_from_lengths(lengths):
@@ -103,10 +76,7 @@ def get_mask_from_lengths(lengths):
 
     max_len = torch.max(lengths).item()
 
-    if use_cuda:
-        ids = torch.tensor(list(range(max_len)), dtype=torch.long, device="cuda")
-    else:
-        ids = torch.tensor(list(range(max_len)), dtype=torch.long, device="cpu")
+    ids = torch.tensor(list(range(max_len)), dtype=torch.long, device=device)
 
     mask = (ids < lengths.unsqueeze(1)).bool()
 
@@ -172,7 +142,7 @@ class ConvNorm(torch.nn.Module):
             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
         )
         if self.use_weight_norm:
-            self.conv = nn.utils.weight_norm(self.conv)
+            self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv)
 
     def forward(self, signal, mask=None):
         if self.use_partial_padding:
@@ -263,7 +233,7 @@ class ConvLSTMLinear(nn.Module):
                 dilation=1,
                 w_init_gain="relu",
             )
-            conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
+            conv_layer = torch.nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
             convolutions.append(conv_layer)
 
         self.convolutions = nn.ModuleList(convolutions)
@@ -281,7 +251,7 @@ class ConvLSTMLinear(nn.Module):
         self.bilstm = nn.LSTM(
             n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
         )
-        lstm_norm_fn_pntr = nn.utils.spectral_norm
+        lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
         self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
         if self.lstm_type == "bilstm":
             self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")
@@ -391,10 +361,10 @@ class Encoder(nn.Module):
         if lstm_norm_fn is not None:
             if "spectral" in lstm_norm_fn:
                 print("Applying spectral norm to text encoder LSTM")
-                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
+                lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
             elif "weight" in lstm_norm_fn:
                 print("Applying weight norm to text encoder LSTM")
-                lstm_norm_fn_pntr = torch.nn.utils.weight_norm
+                lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
             self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
             self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")
 
@@ -450,7 +420,7 @@ class Invertible1x1ConvLUS(torch.nn.Module):
         # Ensure determinant is 1.0 not -1.0
         if torch.det(W) < 0:
             W[:, 0] = -1 * W[:, 0]
-        p, lower, upper = torch.lu_unpack(*torch.lu(W))
+        p, lower, upper = torch.lu_unpack(*torch.linalg.lu_factor(W))
 
         self.register_buffer("p", p)
         # diagonals of lower will always be 1s anyway
@@ -616,7 +586,7 @@ class WN(torch.nn.Module):
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
         start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
-        start = torch.nn.utils.weight_norm(start, name="weight")
+        start = torch.nn.utils.parametrizations.weight_norm(start, name="weight")
         self.start = start
         self.softplus = torch.nn.Softplus()
         self.affine_activation = affine_activation
@@ -645,7 +615,7 @@ class WN(torch.nn.Module):
             # in_layer = nn.utils.weight_norm(in_layer)
             self.in_layers.append(in_layer)
             res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
-            res_skip_layer = nn.utils.weight_norm(res_skip_layer)
+            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer)
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(
@@ -823,7 +793,7 @@ class SplineTransformationLayer(torch.nn.Module):
         # output is unnormalized bin weights
 
     def forward(self, z, context, inverse=False, seq_lens=None):
-        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
+        b_s, _, t_s = z.size(0), z.size(1), z.size(2)
 
         # condition on z_0, transform z_1
         n_half = self.half_mel_channels
@@ -1085,3 +1055,24 @@ class ConvAttention(torch.nn.Module):
 
         attn = self.softmax(attn)  # softmax along T2
         return attn, attn_logprob
+
+
+def update_params(config, params):
+    for param in params:
+        print(param)
+        k, v = param.split("=")
+        try:
+            v = ast.literal_eval(v)
+        except Exception as e:
+            print(e)
+
+        k_split = k.split(".")
+        if len(k_split) > 1:
+            parent_k = k_split[0]
+            cur_param = [".".join(k_split[1:]) + "=" + str(v)]
+            update_params(config[parent_k], cur_param)
+        elif k in config and len(k_split) == 1:
+            print(f"overriding {k} with {v}")
+            config[k] = v
+        else:
+            print("{}, {} params not updated".format(k, v))
data.py
CHANGED
@@ -39,21 +39,21 @@
 ###############################################################################
 
 import os
-import argparse
-import json
-import numpy as np
-import lmdb
 import pickle as pkl
+
+import lmdb
 import torch
 import torch.utils.data
+import numpy as np
+
+from librosa import pyin
 from scipy.io.wavfile import read
-from audio_processing import TacotronSTFT
-from tts_text_processing.text_processing import TextProcessing
 from scipy.stats import betabinom
-from librosa import pyin
-from common import update_params
 from scipy.ndimage import distance_transform_edt as distance_transform
 
+from audio_processing import TacotronSTFT
+from tts_text_processing.text_processing import TextProcessing
+
 
 def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
     P = phoneme_count
@@ -401,7 +401,8 @@ class Data(torch.utils.data.Dataset):
         elif os.path.exists(f0_path):
             try:
                 dikt = torch.load(f0_path)
-            except:
+            except Exception as e:
+                print(e)
                 print(f"f0 loading from {f0_path} is broken, recomputing.")
 
         if dikt is not None:
@@ -460,147 +461,3 @@ class Data(torch.utils.data.Dataset):
 
     def __len__(self):
         return len(self.data)
-
-
-class DataCollate:
-    """Zero-pads model inputs and targets given number of steps"""
-
-    def __init__(self, n_frames_per_step=1):
-        self.n_frames_per_step = n_frames_per_step
-
-    def __call__(self, batch):
-        """Collate from normalized data"""
-        # Right zero-pad all one-hot text sequences to max input length
-        input_lengths, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([len(x["text_encoded"]) for x in batch]),
-            dim=0,
-            descending=True,
-        )
-
-        max_input_len = input_lengths[0]
-        text_padded = torch.LongTensor(len(batch), max_input_len)
-        text_padded.zero_()
-
-        for i in range(len(ids_sorted_decreasing)):
-            text = batch[ids_sorted_decreasing[i]]["text_encoded"]
-            text_padded[i, : text.size(0)] = text
-
-        # Right zero-pad mel-spec
-        num_mel_channels = batch[0]["mel"].size(0)
-        max_target_len = max([x["mel"].size(1) for x in batch])
-
-        # include mel padded, gate padded and speaker ids
-        mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
-        mel_padded.zero_()
-        f0_padded = None
-        p_voiced_padded = None
-        voiced_mask_padded = None
-        energy_avg_padded = None
-        if batch[0]["f0"] is not None:
-            f0_padded = torch.FloatTensor(len(batch), max_target_len)
-            f0_padded.zero_()
-
-        if batch[0]["p_voiced"] is not None:
-            p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
-            p_voiced_padded.zero_()
-
-        if batch[0]["voiced_mask"] is not None:
-            voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
-            voiced_mask_padded.zero_()
-
-        if batch[0]["energy_avg"] is not None:
-            energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
-            energy_avg_padded.zero_()
-
-        attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
-        attn_prior_padded.zero_()
-
-        output_lengths = torch.LongTensor(len(batch))
-        speaker_ids = torch.LongTensor(len(batch))
-        audiopaths = []
-        for i in range(len(ids_sorted_decreasing)):
-            mel = batch[ids_sorted_decreasing[i]]["mel"]
-            mel_padded[i, :, : mel.size(1)] = mel
-            if batch[ids_sorted_decreasing[i]]["f0"] is not None:
-                f0 = batch[ids_sorted_decreasing[i]]["f0"]
-                f0_padded[i, : len(f0)] = f0
-
-            if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
-                voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
-                voiced_mask_padded[i, : len(f0)] = voiced_mask
-
-            if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
-                p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
-                p_voiced_padded[i, : len(f0)] = p_voiced
-
-            if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
-                energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
-                energy_avg_padded[i, : len(energy_avg)] = energy_avg
-
-            output_lengths[i] = mel.size(1)
-            speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
-            audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
-            audiopaths.append(audiopath)
-            cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
-            if cur_attn_prior is None:
-                attn_prior_padded = None
-            else:
-                attn_prior_padded[
-                    i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
-                ] = cur_attn_prior
-
-        return {
-            "mel": mel_padded,
-            "speaker_ids": speaker_ids,
-            "text": text_padded,
-            "input_lengths": input_lengths,
-            "output_lengths": output_lengths,
-            "audiopaths": audiopaths,
-            "attn_prior": attn_prior_padded,
-            "f0": f0_padded,
-            "p_voiced": p_voiced_padded,
-            "voiced_mask": voiced_mask_padded,
-            "energy_avg": energy_avg_padded,
-        }
-
-
-# ===================================================================
-# Takes directory of clean audio and makes directory of spectrograms
-# Useful for making test sets
-# ===================================================================
-if __name__ == "__main__":
-    # Get defaults so it can work with no Sacred
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
-    parser.add_argument("-p", "--params", nargs="+", default=[])
-    args = parser.parse_args()
-    args.rank = 0
-
-    # Parse configs. Globals nicer in this case
-    with open(args.config) as f:
-        data = f.read()
-
-    config = json.loads(data)
-    update_params(config, args.params)
-    print(config)
-
-    data_config = config["data_config"]
-
-    ignore_keys = ["training_files", "validation_files"]
-    trainset = Data(
-        data_config["training_files"],
-        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
-    )
-
-    valset = Data(
-        data_config["validation_files"],
-        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
-        speaker_ids=trainset.speaker_ids,
-    )
-
-    collate_fn = DataCollate()
-
-    for dataset in (trainset, valset):
-        for i, batch in enumerate(dataset):
-            out = batch
-            print("{}/{}".format(i, len(dataset)))
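The `except:` → `except Exception as e:` change here (and again in radtts.py) is more than style: a bare `except` also swallows `KeyboardInterrupt` and `SystemExit` and hides the original error. A self-contained illustration of the narrowed pattern:

    try:
        value = int("not a number")
    except Exception as e:  # Exception excludes KeyboardInterrupt / SystemExit
        print(e)            # the real cause is logged instead of silently dropped
        value = None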
export_weights.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+
+radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
+radtts_path_state = "models/radtts-pp-dap-model/model_dap_84000_state.pt"
+
+checkpoint_dict = torch.load(radtts_path, map_location="cpu")
+
+del checkpoint_dict['iteration']
+del checkpoint_dict['optimizer']
+del checkpoint_dict['learning_rate']
+
+print(checkpoint_dict.keys())
+
+torch.save(checkpoint_dict, radtts_path_state)
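Dropping `iteration`, `optimizer`, and `learning_rate` leaves only the model weights, which shrinks the checkpoint (optimizer state is often as large as the weights themselves) and avoids shipping training-only state to the Space. A quick sanity check for the exported file; the surviving key name is an assumption to verify against the `print(checkpoint_dict.keys())` output above:

    import torch

    slim = torch.load(
        "models/radtts-pp-dap-model/model_dap_84000_state.pt", map_location="cpu"
    )
    print(slim.keys())  # expected to hold only model weights, e.g. "state_dict"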
loss.py
DELETED
@@ -1,228 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: MIT
-#
-# Permission is hereby granted, free of charge, to any person obtaining a
-# copy of this software and associated documentation files (the "Software"),
-# to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense,
-# and/or sell copies of the Software, and to permit persons to whom the
-# Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
-# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from common import get_mask_from_lengths
-
-
-def compute_flow_loss(
-    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
-):
-    log_det_W_total = 0.0
-    for i, log_s in enumerate(log_s_list):
-        if i == 0:
-            log_s_total = torch.sum(log_s * mask)
-            if len(log_det_W_list):
-                log_det_W_total = log_det_W_list[i]
-        else:
-            log_s_total = log_s_total + torch.sum(log_s * mask)
-            if len(log_det_W_list):
-                log_det_W_total += log_det_W_list[i]
-
-    if len(log_det_W_list):
-        log_det_W_total *= n_elements
-
-    z = z * mask
-    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)
-
-    loss = prior_NLL - log_s_total - log_det_W_total
-
-    denom = n_elements * n_dims
-    loss = loss / denom
-    loss_prior = prior_NLL / denom
-    return loss, loss_prior
-
-
-def compute_regression_loss(x_hat, x, mask, name=False):
-    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
-    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
-    assert len(x.shape) == len(mask.shape)
-
-    x = x * mask
-    x_hat = x_hat * mask
-
-    if name == "vpred":
-        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
-    else:
-        loss = F.mse_loss(x_hat, x, reduction="sum")
-    loss = loss / mask.sum()
-
-    loss_dict = {"loss_{}".format(name): loss}
-
-    return loss_dict
-
-
-class AttributePredictionLoss(torch.nn.Module):
-    def __init__(self, name, model_config, loss_weight, sigma=1.0):
-        super(AttributePredictionLoss, self).__init__()
-        self.name = name
-        self.sigma = sigma
-        self.model_name = model_config["name"]
-        self.loss_weight = loss_weight
-        self.n_group_size = 1
-        if "n_group_size" in model_config["hparams"]:
-            self.n_group_size = model_config["hparams"]["n_group_size"]
-
-    def forward(self, model_output, lens):
-        mask = get_mask_from_lengths(lens // self.n_group_size)
-        mask = mask[:, None].float()
-        loss_dict = {}
-        if "z" in model_output:
-            n_elements = lens.sum() // self.n_group_size
-            n_dims = model_output["z"].size(1)
-
-            loss, loss_prior = compute_flow_loss(
-                model_output["z"],
-                model_output["log_det_W_list"],
-                model_output["log_s_list"],
-                n_elements,
-                n_dims,
-                mask,
-                self.sigma,
-            )
-            loss_dict = {
-                "loss_{}".format(self.name): (loss, self.loss_weight),
-                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
-            }
-        elif "x_hat" in model_output:
-            loss_dict = compute_regression_loss(
-                model_output["x_hat"], model_output["x"], mask, self.name
-            )
-            for k, v in loss_dict.items():
-                loss_dict[k] = (v, self.loss_weight)
-
-        if len(loss_dict) == 0:
-            raise Exception("loss not supported")
-
-        return loss_dict
-
-
-class AttentionCTCLoss(torch.nn.Module):
-    def __init__(self, blank_logprob=-1):
-        super(AttentionCTCLoss, self).__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim=3)
-        self.blank_logprob = blank_logprob
-        self.CTCLoss = nn.CTCLoss(zero_infinity=True)
-
-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
-        attn_logprob_padded = F.pad(
-            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
-        )
-        cost_total = 0.0
-        for bid in range(attn_logprob.shape[0]):
-            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
-            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
-                : query_lens[bid], :, : key_lens[bid] + 1
-            ]
-            curr_logprob = self.log_softmax(curr_logprob[None])[0]
-            ctc_cost = self.CTCLoss(
-                curr_logprob,
-                target_seq,
-                input_lengths=query_lens[bid : bid + 1],
-                target_lengths=key_lens[bid : bid + 1],
-            )
-            cost_total += ctc_cost
-        cost = cost_total / attn_logprob.shape[0]
-        return cost
-
-
-class AttentionBinarizationLoss(torch.nn.Module):
-    def __init__(self):
-        super(AttentionBinarizationLoss, self).__init__()
-
-    def forward(self, hard_attention, soft_attention):
-        log_sum = torch.log(soft_attention[hard_attention == 1]).sum()
-        return -log_sum / hard_attention.sum()
-
-
-class RADTTSLoss(torch.nn.Module):
-    def __init__(
-        self,
-        sigma=1.0,
-        n_group_size=1,
-        dur_model_config=None,
-        f0_model_config=None,
-        energy_model_config=None,
-        vpred_model_config=None,
-        loss_weights=None,
-    ):
-        super(RADTTSLoss, self).__init__()
-        self.sigma = sigma
-        self.n_group_size = n_group_size
-        self.loss_weights = loss_weights
-        self.attn_ctc_loss = AttentionCTCLoss(
-            blank_logprob=loss_weights.get("blank_logprob", -1)
-        )
-        self.loss_fns = {}
-        if dur_model_config is not None:
-            self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
-                "duration", dur_model_config, loss_weights["dur_loss_weight"]
-            )
-
-        if f0_model_config is not None:
-            self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
-                "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
-            )
-
-        if energy_model_config is not None:
-            self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
-                "energy", energy_model_config, loss_weights["energy_loss_weight"]
-            )
-
-        if vpred_model_config is not None:
-            self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
-                "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
-            )
-
-    def forward(self, model_output, in_lens, out_lens):
-        loss_dict = {}
-        if len(model_output["z_mel"]):
-            n_elements = out_lens.sum() // self.n_group_size
-            mask = get_mask_from_lengths(out_lens // self.n_group_size)
-            mask = mask[:, None].float()
-            n_dims = model_output["z_mel"].size(1)
-            loss_mel, loss_prior_mel = compute_flow_loss(
-                model_output["z_mel"],
-                model_output["log_det_W_list"],
-                model_output["log_s_list"],
-                n_elements,
-                n_dims,
-                mask,
-                self.sigma,
-            )
-            loss_dict["loss_mel"] = (loss_mel, 1.0)  # loss, weight
-            loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)
-
-        ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
-        loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])
-
-        for k in model_output:
-            if k in self.loss_fns:
-                if model_output[k] is not None and len(model_output[k]) > 0:
-                    t_lens = in_lens if "dur" in k else out_lens
-                    mout = model_output[k]
-                    for loss_name, v in self.loss_fns[k](mout, t_lens).items():
-                        loss_dict[loss_name] = v
-
-        return loss_dict
partialconv1d.py
CHANGED
@@ -13,10 +13,9 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
 
 
-class PartialConv1d(nn.Conv1d):
+class PartialConv1d(torch.nn.Conv1d):
    def __init__(self, *args, **kwargs):
        self.multi_channel = False
        self.return_mask = False
radam.py
DELETED
@@ -1,114 +0,0 @@
-# Original source taken from https://github.com/LiyuanLucasLiu/RAdam
-#
-# Copyright 2019 Liyuan Liu
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-
-import torch
-
-# pylint: disable=no-name-in-module
-from torch.optim.optimizer import Optimizer
-
-
-class RAdam(Optimizer):
-    """RAdam optimizer"""
-
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
-        """
-        Init
-
-        :param params: parameters to optimize
-        :param lr: learning rate
-        :param betas: beta
-        :param eps: numerical precision
-        :param weight_decay: weight decay weight
-        """
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        self.buffer = [[None, None, None] for _ in range(10)]
-        super().__init__(params, defaults)
-
-    def step(self, closure=None):
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data.float()
-                if grad.is_sparse:
-                    raise RuntimeError("RAdam does not support sparse gradients")
-
-                p_data_fp32 = p.data.float()
-
-                state = self.state[p]
-
-                if len(state) == 0:
-                    state["step"] = 0
-                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
-                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
-                else:
-                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
-                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
-
-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-                beta1, beta2 = group["betas"]
-
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-
-                state["step"] += 1
-                buffered = self.buffer[int(state["step"] % 10)]
-                if state["step"] == buffered[0]:
-                    N_sma, step_size = buffered[1], buffered[2]
-                else:
-                    buffered[0] = state["step"]
-                    beta2_t = beta2 ** state["step"]
-                    N_sma_max = 2 / (1 - beta2) - 1
-                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
-                    buffered[1] = N_sma
-
-                    # more conservative since it's an approximated value
-                    if N_sma >= 5:
-                        step_size = (
-                            group["lr"]
-                            * math.sqrt(
-                                (1 - beta2_t)
-                                * (N_sma - 4)
-                                / (N_sma_max - 4)
-                                * (N_sma - 2)
-                                / N_sma
-                                * N_sma_max
-                                / (N_sma_max - 2)
-                            )
-                            / (1 - beta1 ** state["step"])
-                        )
-                    else:
-                        step_size = group["lr"] / (1 - beta1 ** state["step"])
-                    buffered[2] = step_size
-
-                if group["weight_decay"] != 0:
-                    p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
-
-                # more conservative since it's an approximated value
-                if N_sma >= 5:
-                    denom = exp_avg_sq.sqrt().add_(group["eps"])
-                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
-                else:
-                    p_data_fp32.add_(-step_size, exp_avg)
-
-                p.data.copy_(p_data_fp32)
-
-        return loss
radtts.py
CHANGED
@@ -28,8 +28,7 @@ from common import AffineTransformationLayer, LinearNorm, ExponentialClass
 from common import get_mask_from_lengths
 from attribute_prediction_model import get_attribute_prediction_model
 from alignment import mas_width1 as mas
-
-use_cuda = torch.cuda.is_available()
+from torch_env import device
 
 
 class FlowStep(nn.Module):
@@ -202,10 +201,10 @@ class RADTTS(torch.nn.Module):
        if context_lstm_norm is not None:
            if "spectral" in context_lstm_norm:
                print("Applying spectral norm to context encoder LSTM")
-               lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
+               lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
            elif "weight" in context_lstm_norm:
                print("Applying weight norm to context encoder LSTM")
-               lstm_norm_fn_pntr = torch.nn.utils.weight_norm
+               lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
 
            self.context_lstm = lstm_norm_fn_pntr(
                self.context_lstm, "weight_hh_l0"
@@ -688,11 +687,10 @@ class RADTTS(torch.nn.Module):
 
        if dur is None:
            # get token durations
-           z_dur = torch.cuda.FloatTensor(batch_size, 1, n_tokens)
-           z_dur = z_dur.normal_() * sigma_dur
+           z_dur = (
+               torch.randn(batch_size, 1, n_tokens, dtype=torch.float32, device=device)
+               * sigma_dur
+           )
 
            dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
            if dur.shape[-1] < txt_enc.shape[-1]:
@@ -752,9 +750,7 @@ class RADTTS(torch.nn.Module):
                    dtype=torch.float32,
                )
                * sigma_f0
-           )
-           if use_cuda:
-               z_f0 = z_f0.cuda()
+           ).to(device)
 
            f0 = self.infer_f0(
                z_f0,
@@ -780,13 +776,11 @@ class RADTTS(torch.nn.Module):
                    n_energy_feature_channels,
                    max_n_frames,
                    dtype=torch.float32,
+                   device=device,
                )
                * sigma_energy
            )
 
-           if use_cuda:
-               z_energy_avg = z_energy_avg.cuda()
-
            energy_avg = self.infer_energy(
                z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
            )[:, 0]
@@ -829,9 +823,7 @@ class RADTTS(torch.nn.Module):
                80 * self.n_group_size,
                max_n_frames // self.n_group_size,
                dtype=torch.float32,
-           )
-           if use_cuda:
-               residual = residual.cuda()
+           ).to(device)
 
            residual = residual * sigma
 
@@ -921,15 +913,17 @@ class RADTTS(torch.nn.Module):
            try:
                nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
                print("Removed spectral norm from {}".format(name))
-           except:
-               pass
+           except Exception as e:
+               print(e)
+
            try:
                nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
                print("Removed spectral norm from {}".format(name))
-           except:
-               pass
+           except Exception as e:
+               print(e)
+
            try:
                nn.utils.remove_weight_norm(module)
                print("Removed wnorm from {}".format(name))
-           except:
-               pass
+           except Exception as e:
+               print(e)
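The noise-sampling changes above all follow one pattern: instead of allocating latents on CPU and conditionally calling `.cuda()`, tensors are created with (or moved once to) the shared `device`. Allocating directly on the device also skips a host-to-device copy; a minimal comparison, assuming only a generic `device` string:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old pattern: CPU allocation followed by a conditional transfer.
    z = torch.randn(1, 80, 100)
    if device == "cuda":
        z = z.cuda()

    # New pattern: one allocation on the target device, no branch, no copy.
    z = torch.randn(1, 80, 100, device=device)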
requirements.txt
CHANGED
@@ -9,7 +9,4 @@ numba
 lmdb
 librosa
 
-unidecode
-inflect
-
 git+https://github.com/langtech-bsc/vocos.git@matcha
torch_env.py
ADDED
@@ -0,0 +1,19 @@
+import torch
+
+seed = 1234
+
+# use_mps = torch.mps.is_available()
+use_mps = False
+use_cuda = torch.cuda.is_available()
+
+if use_mps:
+    device = "mps"
+    torch.mps.manual_seed(seed)
+elif use_cuda:
+    device = "cuda"
+    torch.cuda.manual_seed(seed)
+else:
+    device = "cpu"
+    torch.manual_seed(seed)
+
+print(f"Inference device: {device}")
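This new module centralizes device selection and seeding so every file imports one `device` value instead of re-deriving `use_cuda` locally. Typical consumption, mirroring the imports this commit adds to app.py, common.py, radtts.py, and autoregressive_flow.py:

    from torch_env import device

    import torch

    x = torch.zeros(4, 80, device=device)  # lands on mps/cuda/cpu uniformly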
tts_text_processing/abbreviations.py
DELETED
@@ -1,57 +0,0 @@
-import re
-
-_no_period_re = re.compile(r"(No[.])(?=[ ]?[0-9])")
-_percent_re = re.compile(r"([ ]?[%])")
-_half_re = re.compile("([0-9]½)|(½)")
-
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [
-    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
-    for x in [
-        ("mrs", "misess"),
-        ("ms", "miss"),
-        ("mr", "mister"),
-        ("dr", "doctor"),
-        ("st", "saint"),
-        ("co", "company"),
-        ("jr", "junior"),
-        ("maj", "major"),
-        ("gen", "general"),
-        ("drs", "doctors"),
-        ("rev", "reverend"),
-        ("lt", "lieutenant"),
-        ("hon", "honorable"),
-        ("sgt", "sergeant"),
-        ("capt", "captain"),
-        ("esq", "esquire"),
-        ("ltd", "limited"),
-        ("col", "colonel"),
-        ("ft", "fort"),
-    ]
-]
-
-
-def _expand_no_period(m):
-    word = m.group(0)
-    if word[0] == "N":
-        return "Number"
-    return "number"
-
-
-def _expand_percent(m):
-    return " percent"
-
-
-def _expand_half(m):
-    word = m.group(1)
-    if word is None:
-        return "half"
-    return word[0] + " and a half"
-
-
-def normalize_abbreviations(text):
-    text = re.sub(_no_period_re, _expand_no_period, text)
-    text = re.sub(_percent_re, _expand_percent, text)
-    text = re.sub(_half_re, _expand_half, text)
-    return text
tts_text_processing/acronyms.py
DELETED
@@ -1,69 +0,0 @@
-import re
-
-_letter_to_arpabet = {
-    "A": "EY1",
-    "B": "B IY1",
-    "C": "S IY1",
-    "D": "D IY1",
-    "E": "IY1",
-    "F": "EH1 F",
-    "G": "JH IY1",
-    "H": "EY1 CH",
-    "I": "AY1",
-    "J": "JH EY1",
-    "K": "K EY1",
-    "L": "EH1 L",
-    "M": "EH1 M",
-    "N": "EH1 N",
-    "O": "OW1",
-    "P": "P IY1",
-    "Q": "K Y UW1",
-    "R": "AA1 R",
-    "S": "EH1 S",
-    "T": "T IY1",
-    "U": "Y UW1",
-    "V": "V IY1",
-    "X": "EH1 K S",
-    "Y": "W AY1",
-    "W": "D AH1 B AH0 L Y UW0",
-    "Z": "Z IY1",
-    "s": "Z",
-}
-
-# must ignore roman numerals
-# _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
-_acronym_re = re.compile(r"([A-Z][A-Z]+)s?")
-
-
-class AcronymNormalizer(object):
-    def __init__(self, phoneme_dict):
-        self.phoneme_dict = phoneme_dict
-
-    def normalize_acronyms(self, text):
-        def _expand_acronyms(m, add_spaces=True):
-            acronym = m.group(0)
-            # remove dots if they exist
-            acronym = re.sub("\.", "", acronym)
-
-            acronym = "".join(acronym.split())
-            arpabet = self.phoneme_dict.lookup(acronym)
-
-            if arpabet is None:
-                acronym = list(acronym)
-                arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
-                # temporary fix
-                if arpabet[-1] == "{Z}" and len(arpabet) > 1:
-                    arpabet[-2] = arpabet[-2][:-1] + " " + arpabet[-1][1:]
-                    del arpabet[-1]
-                arpabet = " ".join(arpabet)
-            elif len(arpabet) == 1:
-                arpabet = "{" + arpabet[0] + "}"
-            else:
-                arpabet = acronym
-            return arpabet
-
-        text = re.sub(_acronym_re, _expand_acronyms, text)
-        return text
-
-    def __call__(self, text):
-        return self.normalize_acronyms(text)
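To illustrate the fallback path in normalize_acronyms: when the phoneme dictionary has no entry, the acronym is spelled out letter by letter via _letter_to_arpabet. A runnable sketch (EmptyDict is a stand-in, not part of the repo):

    class EmptyDict:
        def lookup(self, word):
            return None  # force the letter-by-letter fallback

    normalizer = AcronymNormalizer(EmptyDict())
    print(normalizer("DNA"))  # expected: "{D IY1} {EH1 N} {EY1}"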
tts_text_processing/cleaners.py
CHANGED
@@ -1,26 +1,8 @@
 """adapted from https://github.com/keithito/tacotron"""
 
-"""
-Cleaners are transformations that run over the input text at both training and eval time.
-
-Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
-hyperparameter. Some cleaners are English-specific. You'll typically want to use:
-  1. "english_cleaners" for English text
-  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
-     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
-  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
-     the symbols in symbols.py to match your data).
-"""
-
 import re
 from string import punctuation
 from functools import reduce
-from unidecode import unidecode
-from .numerical import normalize_numbers, normalize_currency
-from .acronyms import AcronymNormalizer
-from .datestime import normalize_datestime
-from .letters_and_numbers import normalize_letters_and_numbers
-from .abbreviations import normalize_abbreviations
 
 
 # Regular expression matching whitespace:
@@ -30,26 +12,6 @@ _whitespace_re = re.compile(r"\s+")
 _arpa_re = re.compile(r"{[^}]+}|\S+")
 
 
-def expand_abbreviations(text):
-    return normalize_abbreviations(text)
-
-
-def expand_numbers(text):
-    return normalize_numbers(text)
-
-
-def expand_currency(text):
-    return normalize_currency(text)
-
-
-def expand_datestime(text):
-    return normalize_datestime(text)
-
-
-def expand_letters_and_numbers(text):
-    return normalize_letters_and_numbers(text)
-
-
 def lowercase(text):
     return text.lower()
 
@@ -58,21 +20,6 @@ def collapse_whitespace(text):
     return re.sub(_whitespace_re, " ", text)
 
 
-def separate_acronyms(text):
-    text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
-    text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
-    return text
-
-
-def convert_to_ascii(text):
-    return unidecode(text)
-
-
-def dehyphenize_compound_words(text):
-    text = re.sub(r"(?<=[a-zA-Z0-9])-(?=[a-zA-Z])", " ", text)
-    return text
-
-
 def remove_space_before_punctuation(text):
     return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)
 
@@ -81,7 +28,6 @@ class Cleaner(object):
     def __init__(self, cleaner_names, phonemedict):
         self.cleaner_names = cleaner_names
         self.phonemedict = phonemedict
-        self.acronym_normalizer = AcronymNormalizer(self.phonemedict)
 
     def __call__(self, text):
         for cleaner_name in self.cleaner_names:
@@ -94,30 +40,13 @@
                 for split in _arpa_re.findall(text)
             ]
             text = " ".join(text)
+
            text = remove_space_before_punctuation(text)
+
         return text
 
     def get_cleaner_fns(self, cleaner_name):
-        if cleaner_name == "basic_cleaners":
-            sequence_fns = [lowercase, collapse_whitespace]
-            word_fns = []
-        elif cleaner_name == "english_cleaners":
-            sequence_fns = [collapse_whitespace, convert_to_ascii, lowercase]
-            word_fns = [expand_numbers, expand_abbreviations]
-        elif cleaner_name == "radtts_cleaners":
-            sequence_fns = [
-                collapse_whitespace,
-                expand_currency,
-                expand_datestime,
-                expand_letters_and_numbers,
-            ]
-            word_fns = [expand_numbers, expand_abbreviations]
-        elif cleaner_name == "ukrainian_cleaners":
-            sequence_fns = [lowercase, collapse_whitespace]
-            word_fns = []
-        elif cleaner_name == "transliteration_cleaners":
-            sequence_fns = [convert_to_ascii, lowercase, collapse_whitespace]
-        else:
-            raise Exception("{} cleaner not supported".format(cleaner_name))
+        sequence_fns = [lowercase, collapse_whitespace]
+        word_fns = []
 
         return sequence_fns, word_fns
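Net effect of this change: whichever cleaner name is configured, the pipeline now only lowercases, collapses whitespace, and tidies spacing before punctuation; all English-specific normalization is gone. A minimal usage sketch (the sample sentence is illustrative):

    from tts_text_processing.cleaners import Cleaner

    cleaner = Cleaner(["ukrainian_cleaners"], {})
    print(cleaner("  Привіт ,  світе !"))  # expected: "привіт, світе!"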
tts_text_processing/cmudict.py
DELETED
@@ -1,140 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-import re
-
-
-valid_symbols = [
-    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1",
-    "AH2", "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0",
-    "AY1", "AY2", "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0",
-    "ER1", "ER2", "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0",
-    "IH1", "IH2", "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
-    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH",
-    "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", "W",
-    "Y", "Z", "ZH",
-]
-
-_valid_symbol_set = set(valid_symbols)
-
-
-class CMUDict:
-    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
-
-    def __init__(self, file_or_path, keep_ambiguous=True):
-        if isinstance(file_or_path, str):
-            with open(file_or_path, encoding="latin-1") as f:
-                entries = _parse_cmudict(f)
-        else:
-            entries = _parse_cmudict(file_or_path)
-        if not keep_ambiguous:
-            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
-        self._entries = entries
-
-    def __len__(self):
-        return len(self._entries)
-
-    def lookup(self, word):
-        """Returns list of ARPAbet pronunciations of the given word."""
-        return self._entries.get(word.upper())
-
-
-_alt_re = re.compile(r"\([0-9]+\)")
-
-
-def _parse_cmudict(file):
-    cmudict = {}
-    for line in file:
-        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
-            parts = line.split("  ")
-            word = re.sub(_alt_re, "", parts[0])
-            pronunciation = _get_pronunciation(parts[1])
-            if pronunciation:
-                if word in cmudict:
-                    cmudict[word].append(pronunciation)
-                else:
-                    cmudict[word] = [pronunciation]
-    return cmudict
-
-
-def _get_pronunciation(s):
-    parts = s.strip().split(" ")
-    for part in parts:
-        if part not in _valid_symbol_set:
-            return None
-    return " ".join(parts)
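For reference, the deleted wrapper parsed dictionary entries of the form "WORD  PRONUNCIATION", with alternative pronunciations marked by (1), (2), and so on. A sketch of how it was used (the two entries are illustrative):

    import io

    entries = io.StringIO("READ  R EH1 D\nREAD(1)  R IY1 D\n")
    cmudict = CMUDict(entries)
    print(cmudict.lookup("read"))  # expected: ["R EH1 D", "R IY1 D"]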
tts_text_processing/datestime.py
DELETED
@@ -1,24 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-import re
-
-_ampm_re = re.compile(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)")
-
-
-def _expand_ampm(m):
-    matches = list(m.groups(0))
-    txt = matches[0]
-    txt = txt if int(matches[1]) == 0 else txt + " " + matches[1]
-
-    if matches[2][0].lower() == "a":
-        txt += " a.m."
-    elif matches[2][0].lower() == "p":
-        txt += " p.m."
-
-    return txt
-
-
-def normalize_datestime(text):
-    text = re.sub(_ampm_re, _expand_ampm, text)
-    # text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
-    return text
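Roughly what this module handled before its removal (a sketch):

    # normalize_datestime("It leaves at 9:30 pm") -> "It leaves at 9 30 p.m."
    # normalize_datestime("10 AM") -> "10 a.m."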
tts_text_processing/grapheme_dictionary.py
DELETED
@@ -1,37 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-import re
-
-_alt_re = re.compile(r"\([0-9]+\)")
-
-
-class Grapheme2PhonemeDictionary:
-    """Thin wrapper around g2p data."""
-
-    def __init__(self, file_or_path, keep_ambiguous=True, encoding="latin-1"):
-        with open(file_or_path, encoding=encoding) as f:
-            entries = _parse_g2p(f)
-        if not keep_ambiguous:
-            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
-        self._entries = entries
-
-    def __len__(self):
-        return len(self._entries)
-
-    def lookup(self, word):
-        """Returns list of pronunciations of the given word."""
-        return self._entries.get(word.upper())
-
-
-def _parse_g2p(file):
-    g2p = {}
-    for line in file:
-        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
-            parts = line.split("  ")
-            word = re.sub(_alt_re, "", parts[0])
-            pronunciation = parts[1].strip()
-            if word in g2p:
-                g2p[word].append(pronunciation)
-            else:
-                g2p[word] = [pronunciation]
-    return g2p
tts_text_processing/heteronyms
DELETED
@@ -1,413 +0,0 @@
The deleted file listed 413 English heteronyms, one per line:

abject, abrogate, absent, abstract, abuse, ache, acre, acuminate, addict, address,
adduct, adele, advocate, affect, affiliate, agape, aged, agglomerate, aggregate, agonic,
agora, allied, ally, alternate, alum, am, analyses, andrea, animate, apply,
appropriate, approximate, ares, arithmetic, arsenic, articulate, associate, attribute, august, axes,
ay, aye, bases, bass, bathed, bested, bifurcate, blessed, blotto, bow,
bowed, bowman, brassy, buffet, bustier, carbonate, celtic, choral, chumash, close,
closer, coax, coincidence, color coordinate, colour coordinate, comber, combine, combs, committee, commune,
compact, complex, compound, compress, concert, conduct, confine, confines, conflict, conglomerate,
conscript, conserve, consist, console, consort, construct, consult, consummate, content, contest,
contract, contracts, contrast, converse, convert, convict, coop, coordinate, covey, crooked,
curate, cussed, decollate, decrease, defect, defense, delegate, deliberate, denier, desert,
detail, deviate, diagnoses, diffuse, digest, discard, discharge, discount, do, document,
does, dogged, domesticate, dominican, dove, dr, drawer, duplicate, egress, ejaculate,
eject, elaborate, ellipses, email, emu, entrace, entrance, escort, estimate, eta,
etna, evening, excise, excuse, exploit, export, extract, fine, flower, forbear,
four-legged, frequent, furrier, gallant, gel, geminate, gillie, glower, gotham, graduate,
haggis, heavy, hinder, house, housewife, impact, imped, implant, implement, import,
impress, incense, incline, increase, infix, insert, instar, insult, integral, intercept,
interchange, interflow, interleaf, intermediate, intern, interspace, intimate, intrigue, invalid, invert,
invite, irony, jagged, jesses, julies, kite, laminate, laos, lather, lead,
learned, leasing, lech, legitimate, lied, lima, lipread, live, lower, lunged,
maas, magdalen, manes, mare, marked, merchandise, merlion, minute, misconduct, misled,
misprint, mobile, moderate, mong, moped, moth, mouth, mow, mpg, multiply,
mush, nana, nice, nice, number, numerate, nun, object, opiate, ornament,
outbox, outcry, outpour, outreach, outride, outright, outside, outwork, overall, overbid,
overcall, overcast, overfall, overflow, overhaul, overhead, overlap, overlay, overuse, overweight,
overwork, pace, palled, palling, para, pasty, pate, pauline, pedal, peer,
perfect, periodic, permit, pervert, pinta, placer, platy, polish, polish, poll,
pontificate, postulate, pram, prayer, precipitate, predate, predicate, prefix, preposition, present,
pretest, primer, proceeds, produce, progress, project, proportionate, prospect, protest, pussy,
putter, putting, quite, ragged, raven, re, read, reading, reading, real,
rebel, recall, recap, recitative, recollect, record, recreate, recreation, redress, refill,
refund, refuse, reject, relay, remake, repaint, reprint, reread, rerun, resent,
reside, resign, respray, resume, retard, retest, retread, rewrite, root, routed,
routing, row, rugged, rummy, sais, sake, sambuca, saucier, second, secrete,
secreted, secreting, segment, separate, sewer, shirk, shower, sin, skied, slaver,
slough, sow, spoof, squid, stingy, subject, subordinate, subvert, supply, supposed,
survey, suspect, syringes, tabulate, tales, tarrier, tarry, taxes, taxis, tear,
theron, thou, three-legged, tier, tinged, torment, transfer, transform, transplant, transport,
transpose, tush, two-legged, unionised, unionized, update, uplift, upset, use, used,
vale, violist, viva, ware, whinged, whoop, wicked, wind, windy, wino,
won, worsted, wound
tts_text_processing/letters_and_numbers.py
DELETED
@@ -1,96 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-import re
-
-_letters_and_numbers_re = re.compile(
-    r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE
-)
-
-_hardware_re = re.compile(
-    "([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)", re.IGNORECASE
-)
-_hardware_key = {
-    "tb": "terabyte",
-    "gb": "gigabyte",
-    "mb": "megabyte",
-    "kb": "kilobyte",
-    "ghz": "gigahertz",
-    "mhz": "megahertz",
-    "khz": "kilohertz",
-    "hz": "hertz",
-    "mm": "millimeter",
-    "cm": "centimeter",
-    "km": "kilometer",
-}
-
-_dimension_re = re.compile(
-    r"\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
-)
-_dimension_key = {"m": "meter", "in": "inch", "inch": "inch"}
-
-
-def _expand_letters_and_numbers(m):
-    text = re.split(r"(\d+)", m.group(0))
-
-    # remove trailing space
-    if text[-1] == "":
-        text = text[:-1]
-    elif text[0] == "":
-        text = text[1:]
-
-    # if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
-    if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
-        text[-2] = text[-2] + text[-1]
-        text = text[:-1]
-
-    # for combining digits 2 by 2
-    new_text = []
-    for i in range(len(text)):
-        string = text[i]
-        if string.isdigit() and len(string) < 5:
-            # heuristics
-            if len(string) > 2 and string[-2] == "0":
-                if string[-1] == "0":
-                    string = [string]
-                else:
-                    string = [string[:-3], string[-2], string[-1]]
-            elif len(string) % 2 == 0:
-                string = [string[i : i + 2] for i in range(0, len(string), 2)]
-            elif len(string) > 2:
-                string = [string[0]] + [
-                    string[i : i + 2] for i in range(1, len(string), 2)
-                ]
-            new_text.extend(string)
-        else:
-            new_text.append(string)
-
-    text = new_text
-    text = " ".join(text)
-    return text
-
-
-def _expand_hardware(m):
-    quantity, measure = m.groups(0)
-    measure = _hardware_key[measure.lower()]
-    if measure[-1] != "z" and float(quantity.replace(",", "")) > 1:
-        return "{} {}s".format(quantity, measure)
-    return "{} {}".format(quantity, measure)
-
-
-def _expand_dimension(m):
-    text = "".join([x for x in m.groups(0) if x != 0])
-    text = text.replace(" x ", " by ")
-    text = text.replace("x", " by ")
-    if text.endswith(tuple(_dimension_key.keys())):
-        if text[-2].isdigit():
-            text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
-        elif text[-3].isdigit():
-            text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
-    return text
-
-
-def normalize_letters_and_numbers(text):
-    text = re.sub(_hardware_re, _expand_hardware, text)
-    text = re.sub(_dimension_re, _expand_dimension, text)
-    text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
-    return text
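Examples of the behavior being dropped here (a sketch):

    # normalize_letters_and_numbers("a 500GB drive") -> "a 500 gigabytes drive"
    # normalize_letters_and_numbers("AK47") -> "AK 47"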
tts_text_processing/numerical.py
DELETED
@@ -1,175 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-import inflect
-import re
-
-_magnitudes = ["trillion", "billion", "million", "thousand", "hundred", "m", "b", "t"]
-_magnitudes_key = {"m": "million", "b": "billion", "t": "trillion"}
-_measurements = "(f|c|k|d|m)"
-_measurements_key = {"f": "fahrenheit", "c": "celsius", "k": "thousand", "m": "meters"}
-_currency_key = {"$": "dollar", "£": "pound", "€": "euro", "₩": "won"}
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
-_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
-_currency_re = re.compile(
-    r"([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?".format(
-        "|".join(_magnitudes)
-    ),
-    re.IGNORECASE,
-)
-_measurement_re = re.compile(
-    r"([0-9\.\,]*[0-9]+(\s)?{}\b)".format(_measurements), re.IGNORECASE
-)
-_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
-# _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
-_roman_re = re.compile(
-    r"\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b"
-)  # avoid I
-_multiply_re = re.compile(r"(\b[0-9]+)(x)([0-9]+)")
-_number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
-
-
-def _remove_commas(m):
-    return m.group(1).replace(",", "")
-
-
-def _expand_decimal_point(m):
-    return m.group(1).replace(".", " point ")
-
-
-def _expand_currency(m):
-    currency = _currency_key[m.group(1)]
-    quantity = m.group(2)
-    magnitude = m.group(3)
-
-    # remove commas from quantity to be able to convert to numerical
-    quantity = quantity.replace(",", "")
-
-    # check for million, billion, etc...
-    if magnitude is not None and magnitude.lower() in _magnitudes:
-        if len(magnitude) == 1:
-            magnitude = _magnitudes_key[magnitude.lower()]
-        return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + "s")
-
-    parts = quantity.split(".")
-    if len(parts) > 2:
-        return quantity + " " + currency + "s"  # Unexpected format
-
-    dollars = int(parts[0]) if parts[0] else 0
-
-    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
-    if dollars and cents:
-        dollar_unit = currency if dollars == 1 else currency + "s"
-        cent_unit = "cent" if cents == 1 else "cents"
-        return "{} {}, {} {}".format(
-            _expand_hundreds(dollars),
-            dollar_unit,
-            _inflect.number_to_words(cents),
-            cent_unit,
-        )
-    elif dollars:
-        dollar_unit = currency if dollars == 1 else currency + "s"
-        return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
-    elif cents:
-        cent_unit = "cent" if cents == 1 else "cents"
-        return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
-    else:
-        return "zero" + " " + currency + "s"
-
-
-def _expand_hundreds(text):
-    number = float(text)
-    if number > 1000 < 10000 and (number % 100 == 0) and (number % 1000 != 0):
-        return _inflect.number_to_words(int(number / 100)) + " hundred"
-    else:
-        return _inflect.number_to_words(text)
-
-
-def _expand_ordinal(m):
-    return _inflect.number_to_words(m.group(0))
-
-
-def _expand_measurement(m):
-    _, number, measurement = re.split("(\d+(?:\.\d+)?)", m.group(0))
-    number = _inflect.number_to_words(number)
-    measurement = "".join(measurement.split())
-    measurement = _measurements_key[measurement.lower()]
-    return "{} {}".format(number, measurement)
-
-
-def _expand_range(m):
-    return " to "
-
-
-def _expand_multiply(m):
-    left = m.group(1)
-    right = m.group(3)
-    return "{} by {}".format(left, right)
-
-
-def _expand_roman(m):
-    # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
-    roman_numerals = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
-    result = 0
-    num = m.group(0)
-    for i, c in enumerate(num):
-        if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
-            result += roman_numerals[c]
-        else:
-            result -= roman_numerals[c]
-    return str(result)
-
-
-def _expand_number(m):
-    _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
-    number = int(number)
-    if (
-        number > 1000
-        and number < 10000
-        and (number % 100 == 0)
-        and (number % 1000 != 0)
-    ):
-        text = _inflect.number_to_words(number // 100) + " hundred"
-    elif number > 1000 and number < 3000:
-        if number == 2000:
-            text = "two thousand"
-        elif number > 2000 and number < 2010:
-            text = "two thousand " + _inflect.number_to_words(number % 100)
-        elif number % 100 == 0:
-            text = _inflect.number_to_words(number // 100) + " hundred"
-        else:
-            number = _inflect.number_to_words(
-                number, andword="", zero="oh", group=2
-            ).replace(", ", " ")
-            number = re.sub(r"-", " ", number)
-            text = number
-    else:
-        number = _inflect.number_to_words(number, andword="and")
-        number = re.sub(r"-", " ", number)
-        number = re.sub(r",", "", number)
-        text = number
-
-    if suffix in ("'s", "s"):
-        if text[-1] == "y":
-            text = text[:-1] + "ies"
-        else:
-            text = text + suffix
-
-    return text
-
-
-def normalize_currency(text):
-    return re.sub(_currency_re, _expand_currency, text)
-
-
-def normalize_numbers(text):
-    text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_currency_re, _expand_currency, text)
-    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
-    text = re.sub(_ordinal_re, _expand_ordinal, text)
-    # text = re.sub(_range_re, _expand_range, text)
-    # text = re.sub(_measurement_re, _expand_measurement, text)
-    text = re.sub(_roman_re, _expand_roman, text)
-    text = re.sub(_multiply_re, _expand_multiply, text)
-    text = re.sub(_number_re, _expand_number, text)
-    return text
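Roughly what this English-specific normalizer did (a sketch; it required the inflect dependency, which this commit also drops from requirements.txt):

    # normalize_numbers("$3,000") -> "three thousand dollars"
    # normalize_numbers("Chapter XIV, 2nd edition") -> "Chapter fourteen, second edition"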
tts_text_processing/symbols.py
DELETED
@@ -1,144 +0,0 @@
-"""adapted from https://github.com/keithito/tacotron"""
-
-"""
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text
-that has been run through Unidecode. For other data, you can modify
-_characters."""
-
-
-arpabet = [
-    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1",
-    "AH2", "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0",
-    "AY1", "AY2", "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0",
-    "ER1", "ER2", "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0",
-    "IH1", "IH2", "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
-    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH",
-    "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", "W",
-    "Y", "Z", "ZH",
-]
-
-
-def get_symbols(symbol_set):
-    if symbol_set == "english_basic":
-        _pad = "_"
-        _punctuation = "!'\"(),.:;? "
-        _special = "-"
-        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-        _arpabet = ["@" + s for s in arpabet]
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == "english_basic_lowercase":
-        _pad = "_"
-        _punctuation = "!'\"(),.:;? "
-        _special = "-"
-        _letters = "abcdefghijklmnopqrstuvwxyz"
-        _arpabet = ["@" + s for s in arpabet]
-        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
-    elif symbol_set == "english_expanded":
-        _punctuation = "!'\",.:;? "
-        _math = "#%&*+-/[]()"
-        _special = "_@©°½—₩€$"
-        _accented = "áçéêëñöøćž"
-        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-        _arpabet = ["@" + s for s in arpabet]
-        symbols = (
-            list(_punctuation + _math + _special + _accented + _letters) + _arpabet
-        )
-    elif symbol_set == "ukrainian":
-        _punctuation = "'.,?! "
-        _special = "-+"
-        _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
-        symbols = list(_punctuation + _special + _letters)
-    elif symbol_set == "radtts":
-        _punctuation = "!'\",.:;? "
-        _math = "#%&*+-/[]()"
-        _special = "_@©°½—₩€$"
-        _accented = "áçéêëñöøćž"
-        _numbers = "0123456789"
-        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-        _arpabet = ["@" + s for s in arpabet]
-        symbols = (
-            list(_punctuation + _math + _special + _accented + _numbers + _letters)
-            + _arpabet
-        )
-    else:
-        raise Exception("{} symbol set does not exist".format(symbol_set))
-
-    return symbols
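Of the branches above, only the "ukrainian" inventory survives this commit: it is exactly what the new in-module get_symbols() in text_processing.py below hard-codes. For reference (a sketch; the count is ours, not from the repo):

    # get_symbols("ukrainian") -> 41 single-character symbols:
    # list("'.,?! " + "-+" + "абвгґдежзийклмнопрстуфхцчшщьюяєії")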
tts_text_processing/text_processing.py
CHANGED
@@ -2,9 +2,8 @@
 
 import re
 import numpy as np
+
 from .cleaners import Cleaner
-from .symbols import get_symbols
-from .grapheme_dictionary import Grapheme2PhonemeDictionary
 
 
 #########
@@ -20,11 +19,14 @@ _words_re = re.compile(
 )
 
 
-def lines_to_list(filename):
-    with open(filename, encoding="utf-8") as f:
-        lines = f.readlines()
-    lines = [line.rstrip() for line in lines]
-    return lines
+def get_symbols():
+    _punctuation = "'.,?! "
+    _special = "-+"
+    _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
+
+    symbols = list(_punctuation + _special + _letters)
+
+    return symbols
 
 
 class TextProcessing(object):
@@ -42,18 +44,14 @@ class TextProcessing(object):
         add_bos_eos_to_text=False,
         encoding="latin-1",
     ):
-        if heteronyms_path is not None and heteronyms_path != "":
-            self.heteronyms = set(lines_to_list(heteronyms_path))
-        else:
-            self.heteronyms = []
-        # phoneme dict
+        self.heteronyms = []
         self.phonemedict = {}
 
         self.p_phoneme = p_phoneme
         self.handle_phoneme = handle_phoneme
        self.handle_phoneme_ambiguous = handle_phoneme_ambiguous
 
-        self.symbols = get_symbols(symbol_set)
+        self.symbols = get_symbols()
         self.cleaner_names = cleaner_name
         self.cleaner = Cleaner(cleaner_name, self.phonemedict)
 
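A short sketch of the new fixed symbol table in use; symbol_to_id mirrors the kind of lookup TextProcessing builds internally (an assumption, since the mapping code is not shown in this diff):

    symbols = get_symbols()
    symbol_to_id = {s: i for i, s in enumerate(symbols)}

    text = "привіт, світе!"
    ids = [symbol_to_id[ch] for ch in text if ch in symbol_to_id]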