Yehor committed on
Commit
8a6f9a8
·
1 Parent(s): a7810c9

Fixes to the codebase

Browse files
app.py CHANGED
@@ -6,38 +6,32 @@ import time
6
  from importlib.metadata import version
7
  from enum import Enum
8
 
9
- from huggingface_hub import hf_hub_download
10
-
11
- use_zerogpu = False
12
 
13
- try:
14
- import spaces # it's for ZeroGPU
15
 
16
- use_zerogpu = True
17
- print("ZeroGPU is available, changing inference call.")
18
- except ImportError:
19
- print("ZeroGPU is not available, skipping...")
 
20
 
21
  import gradio as gr
22
 
23
- import torch
24
- import torchaudio
25
 
26
  # Vocos
27
  from vocos import Vocos
28
 
29
- # RAD-TTS code
30
- from radtts import RADTTS
31
- from data import Data
32
- from common import update_params
33
 
34
- use_cuda = torch.cuda.is_available()
 
35
 
36
- if use_cuda:
37
- print("CUDA is available, setting correct inference_device variable.")
38
- device = "cuda"
39
- else:
40
- device = "cpu"
41
 
42
 
43
  def download_file_from_repo(
@@ -65,15 +59,13 @@ def download_file_from_repo(
65
 
66
  download_file_from_repo(
67
  "Yehor/radtts-uk",
68
- "radtts-pp-dap-model/model_dap_84000.pt",
69
  "./models/",
70
  )
71
 
72
  # Init the model
73
- seed = 1234
74
-
75
  config = "configs/radtts-pp-dap-model.json"
76
- radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
77
 
78
  params = []
79
 
@@ -87,19 +79,11 @@ update_params(config, params)
87
  data_config = config["data_config"]
88
  model_config = config["model_config"]
89
 
90
- # Seed
91
- if use_cuda:
92
- torch.cuda.manual_seed(seed)
93
- else:
94
- torch.manual_seed(seed)
95
-
96
  # Load vocoder
97
  vocos = Vocos.from_pretrained("patriotyk/vocos-mel-hifigan-compat-44100khz").to(device)
98
 
99
  # Load RAD-TTS
100
- radtts = RADTTS(**model_config)
101
- if use_cuda:
102
- radtts = radtts.cuda()
103
 
104
  radtts.enable_inverse_cache() # cache inverse matrix for 1x1 invertible convs
105
 
@@ -151,6 +135,7 @@ tech_env = f"""
151
  #### Environment
152
 
153
  - Python: {sys.version}
 
154
  """.strip()
155
 
156
  tech_libraries = f"""
@@ -161,8 +146,6 @@ tech_libraries = f"""
161
  - scipy: {version("scipy")}
162
  - numba: {version("numba")}
163
  - librosa: {version("librosa")}
164
- - unidecode: {version("unidecode")}
165
- - inflect: {version("inflect")}
166
  """.strip()
167
 
168
 
@@ -218,25 +201,16 @@ def inference(text, voice):
218
  energy_mean = 0
219
  energy_std = 0
220
 
221
- tensor_text = trainset.get_text(text)
222
 
223
- speaker_id = trainset.get_speaker_id(speaker)
224
  speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
225
 
226
  if speaker_text is not None:
227
- speaker_id_text = trainset.get_speaker_id(speaker_text)
228
 
229
  if speaker_attributes is not None:
230
- speaker_id_attributes = trainset.get_speaker_id(speaker_attributes)
231
-
232
- if use_cuda:
233
- tensor_text = tensor_text.cuda()
234
- speaker_id = speaker_id.cuda()
235
-
236
- if speaker_id_text is not None:
237
- speaker_id_text = speaker_id_text.cuda()
238
- if speaker_id_attributes is not None:
239
- speaker_id_attributes = speaker_id_attributes.cuda()
240
 
241
  inference_start = time.time()
242
 
 
6
  from importlib.metadata import version
7
  from enum import Enum
8
 
9
+ import torch
10
+ import torchaudio
 
11
 
12
+ from huggingface_hub import hf_hub_download
 
13
 
14
+ # RAD-TTS code
15
+ from radtts import RADTTS
16
+ from data import Data
17
+ from common import update_params
18
+ from torch_env import device
19
 
20
  import gradio as gr
21
 
 
 
22
 
23
  # Vocos
24
  from vocos import Vocos
25
 
26
+ use_zerogpu = False
 
 
 
27
 
28
+ try:
29
+ import spaces # it's for ZeroGPU
30
 
31
+ use_zerogpu = True
32
+ print("ZeroGPU is available, changing inference call.")
33
+ except ImportError:
34
+ print("ZeroGPU is not available, skipping...")
 
35
 
36
 
37
  def download_file_from_repo(
 
59
 
60
  download_file_from_repo(
61
  "Yehor/radtts-uk",
62
+ "radtts-pp-dap-model/model_dap_84000_state.pt",
63
  "./models/",
64
  )
65
 
66
  # Init the model
 
 
67
  config = "configs/radtts-pp-dap-model.json"
68
+ radtts_path = "models/radtts-pp-dap-model/model_dap_84000_state.pt"
69
 
70
  params = []
71
 
 
79
  data_config = config["data_config"]
80
  model_config = config["model_config"]
81
 
 
 
 
 
 
 
82
  # Load vocoder
83
  vocos = Vocos.from_pretrained("patriotyk/vocos-mel-hifigan-compat-44100khz").to(device)
84
 
85
  # Load RAD-TTS
86
+ radtts = RADTTS(**model_config).to(device)
 
 
87
 
88
  radtts.enable_inverse_cache() # cache inverse matrix for 1x1 invertible convs
89
 
 
135
  #### Environment
136
 
137
  - Python: {sys.version}
138
+ - Torch device: {device}
139
  """.strip()
140
 
141
  tech_libraries = f"""
 
146
  - scipy: {version("scipy")}
147
  - numba: {version("numba")}
148
  - librosa: {version("librosa")}
 
 
149
  """.strip()
150
 
151
 
 
201
  energy_mean = 0
202
  energy_std = 0
203
 
204
+ tensor_text = trainset.get_text(text).to(device)
205
 
206
+ speaker_id = trainset.get_speaker_id(speaker).to(device)
207
  speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
208
 
209
  if speaker_text is not None:
210
+ speaker_id_text = trainset.get_speaker_id(speaker_text).to(device)
211
 
212
  if speaker_attributes is not None:
213
+ speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).to(device)
 
 
 
 
 
 
 
 
 
214
 
215
  inference_start = time.time()
216
 
attribute_prediction_model.py CHANGED
@@ -18,8 +18,10 @@
18
  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
  # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
  # DEALINGS IN THE SOFTWARE.
 
21
  import torch
22
  from torch import nn
 
23
  from common import ConvNorm, Invertible1x1Conv
24
  from common import AffineTransformationLayer, SplineTransformationLayer
25
  from common import ConvLSTMLinear
 
18
  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
  # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
  # DEALINGS IN THE SOFTWARE.
21
+
22
  import torch
23
  from torch import nn
24
+
25
  from common import ConvNorm, Invertible1x1Conv
26
  from common import AffineTransformationLayer, SplineTransformationLayer
27
  from common import ConvLSTMLinear
audio_processing.py CHANGED
@@ -18,12 +18,50 @@
18
  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
  # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
  # DEALINGS IN THE SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  import torch
22
  import numpy as np
 
23
  from scipy.signal import get_window
24
  from librosa.filters import mel as librosa_mel_fn
25
  import librosa.util as librosa_util
26
 
 
 
 
 
27
 
28
  def window_sumsquare(
29
  window,
@@ -174,43 +212,6 @@ class TacotronSTFT(torch.nn.Module):
174
  return mel_output
175
 
176
 
177
- """
178
- BSD 3-Clause License
179
-
180
- Copyright (c) 2017, Prem Seetharaman
181
- All rights reserved.
182
-
183
- * Redistribution and use in source and binary forms, with or without
184
- modification, are permitted provided that the following conditions are met:
185
-
186
- * Redistributions of source code must retain the above copyright notice,
187
- this list of conditions and the following disclaimer.
188
-
189
- * Redistributions in binary form must reproduce the above copyright notice, this
190
- list of conditions and the following disclaimer in the
191
- documentation and/or other materials provided with the distribution.
192
-
193
- * Neither the name of the copyright holder nor the names of its
194
- contributors may be used to endorse or promote products derived from this
195
- software without specific prior written permission.
196
-
197
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
198
- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
199
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
200
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
201
- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
202
- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
203
- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
204
- ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
205
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
206
- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
207
- """
208
- import torch.nn.functional as F
209
- from torch.autograd import Variable
210
- from scipy.signal import get_window
211
- from librosa.util import pad_center, tiny
212
-
213
-
214
  class STFT(torch.nn.Module):
215
  """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
216
 
 
18
  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
  # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
  # DEALINGS IN THE SOFTWARE.
21
+
22
+ """
23
+ BSD 3-Clause License
24
+
25
+ Copyright (c) 2017, Prem Seetharaman
26
+ All rights reserved.
27
+
28
+ * Redistribution and use in source and binary forms, with or without
29
+ modification, are permitted provided that the following conditions are met:
30
+
31
+ * Redistributions of source code must retain the above copyright notice,
32
+ this list of conditions and the following disclaimer.
33
+
34
+ * Redistributions in binary form must reproduce the above copyright notice, this
35
+ list of conditions and the following disclaimer in the
36
+ documentation and/or other materials provided with the distribution.
37
+
38
+ * Neither the name of the copyright holder nor the names of its
39
+ contributors may be used to endorse or promote products derived from this
40
+ software without specific prior written permission.
41
+
42
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
43
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
44
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
45
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
46
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
47
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
48
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
49
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
51
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
52
+ """
53
+
54
  import torch
55
  import numpy as np
56
+
57
  from scipy.signal import get_window
58
  from librosa.filters import mel as librosa_mel_fn
59
  import librosa.util as librosa_util
60
 
61
+ import torch.nn.functional as F
62
+ from torch.autograd import Variable
63
+ from librosa.util import pad_center, tiny
64
+
65
 
66
  def window_sumsquare(
67
  window,
 
212
  return mel_output
213
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  class STFT(torch.nn.Module):
216
  """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
217
 
autoregressive_flow.py CHANGED
@@ -45,8 +45,7 @@ import torch
45
  from torch import nn
46
 
47
  from common import DenseLayer, SplineTransformationLayerAR
48
-
49
- use_cuda = torch.cuda.is_available()
50
 
51
 
52
  class AR_Back_Step(torch.nn.Module):
@@ -229,10 +228,7 @@ class AR_Step(torch.nn.Module):
229
  (1, residual.size(1), residual.size(2)), dtype=residual.dtype
230
  )
231
 
232
- if use_cuda:
233
- dummy = torch.tensor(data, device="cuda")
234
- else:
235
- dummy = torch.tensor(data)
236
 
237
  self.attr_lstm.flatten_parameters()
238
 
 
45
  from torch import nn
46
 
47
  from common import DenseLayer, SplineTransformationLayerAR
48
+ from torch_env import device
 
49
 
50
 
51
  class AR_Back_Step(torch.nn.Module):
 
228
  (1, residual.size(1), residual.size(2)), dtype=residual.dtype
229
  )
230
 
231
+ dummy = torch.tensor(data, device=device)
 
 
 
232
 
233
  self.attr_lstm.flatten_parameters()
234
 
common.py CHANGED
@@ -62,34 +62,7 @@ from splines import (
62
  )
63
  from partialconv1d import PartialConv1d as pconv1d
64
  from typing import Tuple
65
-
66
- use_cuda = torch.cuda.is_available()
67
-
68
- if use_cuda:
69
- device = "cuda"
70
- else:
71
- device = "cpu"
72
-
73
-
74
- def update_params(config, params):
75
- for param in params:
76
- print(param)
77
- k, v = param.split("=")
78
- try:
79
- v = ast.literal_eval(v)
80
- except:
81
- pass
82
-
83
- k_split = k.split(".")
84
- if len(k_split) > 1:
85
- parent_k = k_split[0]
86
- cur_param = [".".join(k_split[1:]) + "=" + str(v)]
87
- update_params(config[parent_k], cur_param)
88
- elif k in config and len(k_split) == 1:
89
- print(f"overriding {k} with {v}")
90
- config[k] = v
91
- else:
92
- print("{}, {} params not updated".format(k, v))
93
 
94
 
95
  def get_mask_from_lengths(lengths):
@@ -103,10 +76,7 @@ def get_mask_from_lengths(lengths):
103
 
104
  max_len = torch.max(lengths).item()
105
 
106
- if use_cuda:
107
- ids = torch.tensor(list(range(max_len)), dtype=torch.long, device="cuda")
108
- else:
109
- ids = torch.tensor(list(range(max_len)), dtype=torch.long, device="cpu")
110
 
111
  mask = (ids < lengths.unsqueeze(1)).bool()
112
 
@@ -172,7 +142,7 @@ class ConvNorm(torch.nn.Module):
172
  self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
173
  )
174
  if self.use_weight_norm:
175
- self.conv = nn.utils.weight_norm(self.conv)
176
 
177
  def forward(self, signal, mask=None):
178
  if self.use_partial_padding:
@@ -263,7 +233,7 @@ class ConvLSTMLinear(nn.Module):
263
  dilation=1,
264
  w_init_gain="relu",
265
  )
266
- conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
267
  convolutions.append(conv_layer)
268
 
269
  self.convolutions = nn.ModuleList(convolutions)
@@ -281,7 +251,7 @@ class ConvLSTMLinear(nn.Module):
281
  self.bilstm = nn.LSTM(
282
  n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
283
  )
284
- lstm_norm_fn_pntr = nn.utils.spectral_norm
285
  self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
286
  if self.lstm_type == "bilstm":
287
  self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")
@@ -391,10 +361,10 @@ class Encoder(nn.Module):
391
  if lstm_norm_fn is not None:
392
  if "spectral" in lstm_norm_fn:
393
  print("Applying spectral norm to text encoder LSTM")
394
- lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
395
  elif "weight" in lstm_norm_fn:
396
  print("Applying weight norm to text encoder LSTM")
397
- lstm_norm_fn_pntr = torch.nn.utils.weight_norm
398
  self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
399
  self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")
400
 
@@ -450,7 +420,7 @@ class Invertible1x1ConvLUS(torch.nn.Module):
450
  # Ensure determinant is 1.0 not -1.0
451
  if torch.det(W) < 0:
452
  W[:, 0] = -1 * W[:, 0]
453
- p, lower, upper = torch.lu_unpack(*torch.lu(W))
454
 
455
  self.register_buffer("p", p)
456
  # diagonals of lower will always be 1s anyway
@@ -616,7 +586,7 @@ class WN(torch.nn.Module):
616
  self.in_layers = torch.nn.ModuleList()
617
  self.res_skip_layers = torch.nn.ModuleList()
618
  start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
619
- start = torch.nn.utils.weight_norm(start, name="weight")
620
  self.start = start
621
  self.softplus = torch.nn.Softplus()
622
  self.affine_activation = affine_activation
@@ -645,7 +615,7 @@ class WN(torch.nn.Module):
645
  # in_layer = nn.utils.weight_norm(in_layer)
646
  self.in_layers.append(in_layer)
647
  res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
648
- res_skip_layer = nn.utils.weight_norm(res_skip_layer)
649
  self.res_skip_layers.append(res_skip_layer)
650
 
651
  def forward(
@@ -823,7 +793,7 @@ class SplineTransformationLayer(torch.nn.Module):
823
  # output is unnormalized bin weights
824
 
825
  def forward(self, z, context, inverse=False, seq_lens=None):
826
- b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
827
 
828
  # condition on z_0, transform z_1
829
  n_half = self.half_mel_channels
@@ -1085,3 +1055,24 @@ class ConvAttention(torch.nn.Module):
1085
 
1086
  attn = self.softmax(attn) # softmax along T2
1087
  return attn, attn_logprob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
  from partialconv1d import PartialConv1d as pconv1d
64
  from typing import Tuple
65
+ from torch_env import device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  def get_mask_from_lengths(lengths):
 
76
 
77
  max_len = torch.max(lengths).item()
78
 
79
+ ids = torch.tensor(list(range(max_len)), dtype=torch.long, device=device)
 
 
 
80
 
81
  mask = (ids < lengths.unsqueeze(1)).bool()
82
 
 
142
  self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
143
  )
144
  if self.use_weight_norm:
145
+ self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv)
146
 
147
  def forward(self, signal, mask=None):
148
  if self.use_partial_padding:
 
233
  dilation=1,
234
  w_init_gain="relu",
235
  )
236
+ conv_layer = torch.nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
237
  convolutions.append(conv_layer)
238
 
239
  self.convolutions = nn.ModuleList(convolutions)
 
251
  self.bilstm = nn.LSTM(
252
  n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
253
  )
254
+ lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
255
  self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
256
  if self.lstm_type == "bilstm":
257
  self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")
 
361
  if lstm_norm_fn is not None:
362
  if "spectral" in lstm_norm_fn:
363
  print("Applying spectral norm to text encoder LSTM")
364
+ lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
365
  elif "weight" in lstm_norm_fn:
366
  print("Applying weight norm to text encoder LSTM")
367
+ lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
368
  self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
369
  self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")
370
 
 
420
  # Ensure determinant is 1.0 not -1.0
421
  if torch.det(W) < 0:
422
  W[:, 0] = -1 * W[:, 0]
423
+ p, lower, upper = torch.lu_unpack(*torch.linalg.lu_factor(W))
424
 
425
  self.register_buffer("p", p)
426
  # diagonals of lower will always be 1s anyway
 
586
  self.in_layers = torch.nn.ModuleList()
587
  self.res_skip_layers = torch.nn.ModuleList()
588
  start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
589
+ start = torch.nn.utils.parametrizations.weight_norm(start, name="weight")
590
  self.start = start
591
  self.softplus = torch.nn.Softplus()
592
  self.affine_activation = affine_activation
 
615
  # in_layer = nn.utils.weight_norm(in_layer)
616
  self.in_layers.append(in_layer)
617
  res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
618
+ res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer)
619
  self.res_skip_layers.append(res_skip_layer)
620
 
621
  def forward(
 
793
  # output is unnormalized bin weights
794
 
795
  def forward(self, z, context, inverse=False, seq_lens=None):
796
+ b_s, _, t_s = z.size(0), z.size(1), z.size(2)
797
 
798
  # condition on z_0, transform z_1
799
  n_half = self.half_mel_channels
 
1055
 
1056
  attn = self.softmax(attn) # softmax along T2
1057
  return attn, attn_logprob
1058
+
1059
+
1060
+ def update_params(config, params):
1061
+ for param in params:
1062
+ print(param)
1063
+ k, v = param.split("=")
1064
+ try:
1065
+ v = ast.literal_eval(v)
1066
+ except Exception as e:
1067
+ print(e)
1068
+
1069
+ k_split = k.split(".")
1070
+ if len(k_split) > 1:
1071
+ parent_k = k_split[0]
1072
+ cur_param = [".".join(k_split[1:]) + "=" + str(v)]
1073
+ update_params(config[parent_k], cur_param)
1074
+ elif k in config and len(k_split) == 1:
1075
+ print(f"overriding {k} with {v}")
1076
+ config[k] = v
1077
+ else:
1078
+ print("{}, {} params not updated".format(k, v))
data.py CHANGED
@@ -39,21 +39,21 @@
39
  ###############################################################################
40
 
41
  import os
42
- import argparse
43
- import json
44
- import numpy as np
45
- import lmdb
46
  import pickle as pkl
 
 
47
  import torch
48
  import torch.utils.data
 
 
 
49
  from scipy.io.wavfile import read
50
- from audio_processing import TacotronSTFT
51
- from tts_text_processing.text_processing import TextProcessing
52
  from scipy.stats import betabinom
53
- from librosa import pyin
54
- from common import update_params
55
  from scipy.ndimage import distance_transform_edt as distance_transform
56
 
 
 
 
57
 
58
  def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
59
  P = phoneme_count
@@ -401,7 +401,8 @@ class Data(torch.utils.data.Dataset):
401
  elif os.path.exists(f0_path):
402
  try:
403
  dikt = torch.load(f0_path)
404
- except:
 
405
  print(f"f0 loading from {f0_path} is broken, recomputing.")
406
 
407
  if dikt is not None:
@@ -460,147 +461,3 @@ class Data(torch.utils.data.Dataset):
460
 
461
  def __len__(self):
462
  return len(self.data)
463
-
464
-
465
- class DataCollate:
466
- """Zero-pads model inputs and targets given number of steps"""
467
-
468
- def __init__(self, n_frames_per_step=1):
469
- self.n_frames_per_step = n_frames_per_step
470
-
471
- def __call__(self, batch):
472
- """Collate from normalized data"""
473
- # Right zero-pad all one-hot text sequences to max input length
474
- input_lengths, ids_sorted_decreasing = torch.sort(
475
- torch.LongTensor([len(x["text_encoded"]) for x in batch]),
476
- dim=0,
477
- descending=True,
478
- )
479
-
480
- max_input_len = input_lengths[0]
481
- text_padded = torch.LongTensor(len(batch), max_input_len)
482
- text_padded.zero_()
483
-
484
- for i in range(len(ids_sorted_decreasing)):
485
- text = batch[ids_sorted_decreasing[i]]["text_encoded"]
486
- text_padded[i, : text.size(0)] = text
487
-
488
- # Right zero-pad mel-spec
489
- num_mel_channels = batch[0]["mel"].size(0)
490
- max_target_len = max([x["mel"].size(1) for x in batch])
491
-
492
- # include mel padded, gate padded and speaker ids
493
- mel_padded = torch.FloatTensor(len(batch), num_mel_channels, max_target_len)
494
- mel_padded.zero_()
495
- f0_padded = None
496
- p_voiced_padded = None
497
- voiced_mask_padded = None
498
- energy_avg_padded = None
499
- if batch[0]["f0"] is not None:
500
- f0_padded = torch.FloatTensor(len(batch), max_target_len)
501
- f0_padded.zero_()
502
-
503
- if batch[0]["p_voiced"] is not None:
504
- p_voiced_padded = torch.FloatTensor(len(batch), max_target_len)
505
- p_voiced_padded.zero_()
506
-
507
- if batch[0]["voiced_mask"] is not None:
508
- voiced_mask_padded = torch.FloatTensor(len(batch), max_target_len)
509
- voiced_mask_padded.zero_()
510
-
511
- if batch[0]["energy_avg"] is not None:
512
- energy_avg_padded = torch.FloatTensor(len(batch), max_target_len)
513
- energy_avg_padded.zero_()
514
-
515
- attn_prior_padded = torch.FloatTensor(len(batch), max_target_len, max_input_len)
516
- attn_prior_padded.zero_()
517
-
518
- output_lengths = torch.LongTensor(len(batch))
519
- speaker_ids = torch.LongTensor(len(batch))
520
- audiopaths = []
521
- for i in range(len(ids_sorted_decreasing)):
522
- mel = batch[ids_sorted_decreasing[i]]["mel"]
523
- mel_padded[i, :, : mel.size(1)] = mel
524
- if batch[ids_sorted_decreasing[i]]["f0"] is not None:
525
- f0 = batch[ids_sorted_decreasing[i]]["f0"]
526
- f0_padded[i, : len(f0)] = f0
527
-
528
- if batch[ids_sorted_decreasing[i]]["voiced_mask"] is not None:
529
- voiced_mask = batch[ids_sorted_decreasing[i]]["voiced_mask"]
530
- voiced_mask_padded[i, : len(f0)] = voiced_mask
531
-
532
- if batch[ids_sorted_decreasing[i]]["p_voiced"] is not None:
533
- p_voiced = batch[ids_sorted_decreasing[i]]["p_voiced"]
534
- p_voiced_padded[i, : len(f0)] = p_voiced
535
-
536
- if batch[ids_sorted_decreasing[i]]["energy_avg"] is not None:
537
- energy_avg = batch[ids_sorted_decreasing[i]]["energy_avg"]
538
- energy_avg_padded[i, : len(energy_avg)] = energy_avg
539
-
540
- output_lengths[i] = mel.size(1)
541
- speaker_ids[i] = batch[ids_sorted_decreasing[i]]["speaker_id"]
542
- audiopath = batch[ids_sorted_decreasing[i]]["audiopath"]
543
- audiopaths.append(audiopath)
544
- cur_attn_prior = batch[ids_sorted_decreasing[i]]["attn_prior"]
545
- if cur_attn_prior is None:
546
- attn_prior_padded = None
547
- else:
548
- attn_prior_padded[
549
- i, : cur_attn_prior.size(0), : cur_attn_prior.size(1)
550
- ] = cur_attn_prior
551
-
552
- return {
553
- "mel": mel_padded,
554
- "speaker_ids": speaker_ids,
555
- "text": text_padded,
556
- "input_lengths": input_lengths,
557
- "output_lengths": output_lengths,
558
- "audiopaths": audiopaths,
559
- "attn_prior": attn_prior_padded,
560
- "f0": f0_padded,
561
- "p_voiced": p_voiced_padded,
562
- "voiced_mask": voiced_mask_padded,
563
- "energy_avg": energy_avg_padded,
564
- }
565
-
566
-
567
- # ===================================================================
568
- # Takes directory of clean audio and makes directory of spectrograms
569
- # Useful for making test sets
570
- # ===================================================================
571
- if __name__ == "__main__":
572
- # Get defaults so it can work with no Sacred
573
- parser = argparse.ArgumentParser()
574
- parser.add_argument("-c", "--config", type=str, help="JSON file for configuration")
575
- parser.add_argument("-p", "--params", nargs="+", default=[])
576
- args = parser.parse_args()
577
- args.rank = 0
578
-
579
- # Parse configs. Globals nicer in this case
580
- with open(args.config) as f:
581
- data = f.read()
582
-
583
- config = json.loads(data)
584
- update_params(config, args.params)
585
- print(config)
586
-
587
- data_config = config["data_config"]
588
-
589
- ignore_keys = ["training_files", "validation_files"]
590
- trainset = Data(
591
- data_config["training_files"],
592
- **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
593
- )
594
-
595
- valset = Data(
596
- data_config["validation_files"],
597
- **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
598
- speaker_ids=trainset.speaker_ids,
599
- )
600
-
601
- collate_fn = DataCollate()
602
-
603
- for dataset in (trainset, valset):
604
- for i, batch in enumerate(dataset):
605
- out = batch
606
- print("{}/{}".format(i, len(dataset)))
 
39
  ###############################################################################
40
 
41
  import os
 
 
 
 
42
  import pickle as pkl
43
+
44
+ import lmdb
45
  import torch
46
  import torch.utils.data
47
+ import numpy as np
48
+
49
+ from librosa import pyin
50
  from scipy.io.wavfile import read
 
 
51
  from scipy.stats import betabinom
 
 
52
  from scipy.ndimage import distance_transform_edt as distance_transform
53
 
54
+ from audio_processing import TacotronSTFT
55
+ from tts_text_processing.text_processing import TextProcessing
56
+
57
 
58
  def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
59
  P = phoneme_count
 
401
  elif os.path.exists(f0_path):
402
  try:
403
  dikt = torch.load(f0_path)
404
+ except Exception as e:
405
+ print(e)
406
  print(f"f0 loading from {f0_path} is broken, recomputing.")
407
 
408
  if dikt is not None:
 
461
 
462
  def __len__(self):
463
  return len(self.data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
export_weights.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"
4
+ radtts_path_state = "models/radtts-pp-dap-model/model_dap_84000_state.pt"
5
+
6
+ checkpoint_dict = torch.load(radtts_path, map_location="cpu")
7
+
8
+ del checkpoint_dict['iteration']
9
+ del checkpoint_dict['optimizer']
10
+ del checkpoint_dict['learning_rate']
11
+
12
+ print(checkpoint_dict.keys())
13
+
14
+ torch.save(checkpoint_dict, radtts_path_state)
loss.py DELETED
@@ -1,228 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: MIT
3
- #
4
- # Permission is hereby granted, free of charge, to any person obtaining a
5
- # copy of this software and associated documentation files (the "Software"),
6
- # to deal in the Software without restriction, including without limitation
7
- # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
- # and/or sell copies of the Software, and to permit persons to whom the
9
- # Software is furnished to do so, subject to the following conditions:
10
- #
11
- # The above copyright notice and this permission notice shall be included in
12
- # all copies or substantial portions of the Software.
13
- #
14
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17
- # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
- # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20
- # DEALINGS IN THE SOFTWARE.
21
- import torch
22
- import torch.nn as nn
23
- from torch.nn import functional as F
24
- from common import get_mask_from_lengths
25
-
26
-
27
- def compute_flow_loss(
28
- z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
29
- ):
30
- log_det_W_total = 0.0
31
- for i, log_s in enumerate(log_s_list):
32
- if i == 0:
33
- log_s_total = torch.sum(log_s * mask)
34
- if len(log_det_W_list):
35
- log_det_W_total = log_det_W_list[i]
36
- else:
37
- log_s_total = log_s_total + torch.sum(log_s * mask)
38
- if len(log_det_W_list):
39
- log_det_W_total += log_det_W_list[i]
40
-
41
- if len(log_det_W_list):
42
- log_det_W_total *= n_elements
43
-
44
- z = z * mask
45
- prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)
46
-
47
- loss = prior_NLL - log_s_total - log_det_W_total
48
-
49
- denom = n_elements * n_dims
50
- loss = loss / denom
51
- loss_prior = prior_NLL / denom
52
- return loss, loss_prior
53
-
54
-
55
- def compute_regression_loss(x_hat, x, mask, name=False):
56
- x = x[:, None] if len(x.shape) == 2 else x # add channel dim
57
- mask = mask[:, None] if len(mask.shape) == 2 else mask # add channel dim
58
- assert len(x.shape) == len(mask.shape)
59
-
60
- x = x * mask
61
- x_hat = x_hat * mask
62
-
63
- if name == "vpred":
64
- loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
65
- else:
66
- loss = F.mse_loss(x_hat, x, reduction="sum")
67
- loss = loss / mask.sum()
68
-
69
- loss_dict = {"loss_{}".format(name): loss}
70
-
71
- return loss_dict
72
-
73
-
74
- class AttributePredictionLoss(torch.nn.Module):
75
- def __init__(self, name, model_config, loss_weight, sigma=1.0):
76
- super(AttributePredictionLoss, self).__init__()
77
- self.name = name
78
- self.sigma = sigma
79
- self.model_name = model_config["name"]
80
- self.loss_weight = loss_weight
81
- self.n_group_size = 1
82
- if "n_group_size" in model_config["hparams"]:
83
- self.n_group_size = model_config["hparams"]["n_group_size"]
84
-
85
- def forward(self, model_output, lens):
86
- mask = get_mask_from_lengths(lens // self.n_group_size)
87
- mask = mask[:, None].float()
88
- loss_dict = {}
89
- if "z" in model_output:
90
- n_elements = lens.sum() // self.n_group_size
91
- n_dims = model_output["z"].size(1)
92
-
93
- loss, loss_prior = compute_flow_loss(
94
- model_output["z"],
95
- model_output["log_det_W_list"],
96
- model_output["log_s_list"],
97
- n_elements,
98
- n_dims,
99
- mask,
100
- self.sigma,
101
- )
102
- loss_dict = {
103
- "loss_{}".format(self.name): (loss, self.loss_weight),
104
- "loss_prior_{}".format(self.name): (loss_prior, 0.0),
105
- }
106
- elif "x_hat" in model_output:
107
- loss_dict = compute_regression_loss(
108
- model_output["x_hat"], model_output["x"], mask, self.name
109
- )
110
- for k, v in loss_dict.items():
111
- loss_dict[k] = (v, self.loss_weight)
112
-
113
- if len(loss_dict) == 0:
114
- raise Exception("loss not supported")
115
-
116
- return loss_dict
117
-
118
-
119
- class AttentionCTCLoss(torch.nn.Module):
120
- def __init__(self, blank_logprob=-1):
121
- super(AttentionCTCLoss, self).__init__()
122
- self.log_softmax = torch.nn.LogSoftmax(dim=3)
123
- self.blank_logprob = blank_logprob
124
- self.CTCLoss = nn.CTCLoss(zero_infinity=True)
125
-
126
- def forward(self, attn_logprob, in_lens, out_lens):
127
- key_lens = in_lens
128
- query_lens = out_lens
129
- attn_logprob_padded = F.pad(
130
- input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
131
- )
132
- cost_total = 0.0
133
- for bid in range(attn_logprob.shape[0]):
134
- target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
135
- curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
136
- : query_lens[bid], :, : key_lens[bid] + 1
137
- ]
138
- curr_logprob = self.log_softmax(curr_logprob[None])[0]
139
- ctc_cost = self.CTCLoss(
140
- curr_logprob,
141
- target_seq,
142
- input_lengths=query_lens[bid : bid + 1],
143
- target_lengths=key_lens[bid : bid + 1],
144
- )
145
- cost_total += ctc_cost
146
- cost = cost_total / attn_logprob.shape[0]
147
- return cost
148
-
149
-
150
- class AttentionBinarizationLoss(torch.nn.Module):
151
- def __init__(self):
152
- super(AttentionBinarizationLoss, self).__init__()
153
-
154
- def forward(self, hard_attention, soft_attention):
155
- log_sum = torch.log(soft_attention[hard_attention == 1]).sum()
156
- return -log_sum / hard_attention.sum()
157
-
158
-
159
- class RADTTSLoss(torch.nn.Module):
160
- def __init__(
161
- self,
162
- sigma=1.0,
163
- n_group_size=1,
164
- dur_model_config=None,
165
- f0_model_config=None,
166
- energy_model_config=None,
167
- vpred_model_config=None,
168
- loss_weights=None,
169
- ):
170
- super(RADTTSLoss, self).__init__()
171
- self.sigma = sigma
172
- self.n_group_size = n_group_size
173
- self.loss_weights = loss_weights
174
- self.attn_ctc_loss = AttentionCTCLoss(
175
- blank_logprob=loss_weights.get("blank_logprob", -1)
176
- )
177
- self.loss_fns = {}
178
- if dur_model_config is not None:
179
- self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
180
- "duration", dur_model_config, loss_weights["dur_loss_weight"]
181
- )
182
-
183
- if f0_model_config is not None:
184
- self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
185
- "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
186
- )
187
-
188
- if energy_model_config is not None:
189
- self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
190
- "energy", energy_model_config, loss_weights["energy_loss_weight"]
191
- )
192
-
193
- if vpred_model_config is not None:
194
- self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
195
- "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
196
- )
197
-
198
- def forward(self, model_output, in_lens, out_lens):
199
- loss_dict = {}
200
- if len(model_output["z_mel"]):
201
- n_elements = out_lens.sum() // self.n_group_size
202
- mask = get_mask_from_lengths(out_lens // self.n_group_size)
203
- mask = mask[:, None].float()
204
- n_dims = model_output["z_mel"].size(1)
205
- loss_mel, loss_prior_mel = compute_flow_loss(
206
- model_output["z_mel"],
207
- model_output["log_det_W_list"],
208
- model_output["log_s_list"],
209
- n_elements,
210
- n_dims,
211
- mask,
212
- self.sigma,
213
- )
214
- loss_dict["loss_mel"] = (loss_mel, 1.0) # loss, weight
215
- loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)
216
-
217
- ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
218
- loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])
219
-
220
- for k in model_output:
221
- if k in self.loss_fns:
222
- if model_output[k] is not None and len(model_output[k]) > 0:
223
- t_lens = in_lens if "dur" in k else out_lens
224
- mout = model_output[k]
225
- for loss_name, v in self.loss_fns[k](mout, t_lens).items():
226
- loss_dict[loss_name] = v
227
-
228
- return loss_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
partialconv1d.py CHANGED
@@ -13,10 +13,9 @@
13
 
14
  import torch
15
  import torch.nn.functional as F
16
- from torch import nn
17
 
18
 
19
- class PartialConv1d(nn.Conv1d):
20
  def __init__(self, *args, **kwargs):
21
  self.multi_channel = False
22
  self.return_mask = False
 
13
 
14
  import torch
15
  import torch.nn.functional as F
 
16
 
17
 
18
+ class PartialConv1d(torch.nn.Conv1d):
19
  def __init__(self, *args, **kwargs):
20
  self.multi_channel = False
21
  self.return_mask = False
radam.py DELETED
@@ -1,114 +0,0 @@
1
- # Original source taken from https://github.com/LiyuanLucasLiu/RAdam
2
- #
3
- # Copyright 2019 Liyuan Liu
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- import math
17
-
18
- import torch
19
-
20
- # pylint: disable=no-name-in-module
21
- from torch.optim.optimizer import Optimizer
22
-
23
-
24
- class RAdam(Optimizer):
25
- """RAdam optimizer"""
26
-
27
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
28
- """
29
- Init
30
-
31
- :param params: parameters to optimize
32
- :param lr: learning rate
33
- :param betas: beta
34
- :param eps: numerical precision
35
- :param weight_decay: weight decay weight
36
- """
37
- defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
38
- self.buffer = [[None, None, None] for _ in range(10)]
39
- super().__init__(params, defaults)
40
-
41
- def step(self, closure=None):
42
- loss = None
43
- if closure is not None:
44
- loss = closure()
45
-
46
- for group in self.param_groups:
47
- for p in group["params"]:
48
- if p.grad is None:
49
- continue
50
- grad = p.grad.data.float()
51
- if grad.is_sparse:
52
- raise RuntimeError("RAdam does not support sparse gradients")
53
-
54
- p_data_fp32 = p.data.float()
55
-
56
- state = self.state[p]
57
-
58
- if len(state) == 0:
59
- state["step"] = 0
60
- state["exp_avg"] = torch.zeros_like(p_data_fp32)
61
- state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
62
- else:
63
- state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
64
- state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
65
-
66
- exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
67
- beta1, beta2 = group["betas"]
68
-
69
- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
70
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
71
-
72
- state["step"] += 1
73
- buffered = self.buffer[int(state["step"] % 10)]
74
- if state["step"] == buffered[0]:
75
- N_sma, step_size = buffered[1], buffered[2]
76
- else:
77
- buffered[0] = state["step"]
78
- beta2_t = beta2 ** state["step"]
79
- N_sma_max = 2 / (1 - beta2) - 1
80
- N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
81
- buffered[1] = N_sma
82
-
83
- # more conservative since it's an approximated value
84
- if N_sma >= 5:
85
- step_size = (
86
- group["lr"]
87
- * math.sqrt(
88
- (1 - beta2_t)
89
- * (N_sma - 4)
90
- / (N_sma_max - 4)
91
- * (N_sma - 2)
92
- / N_sma
93
- * N_sma_max
94
- / (N_sma_max - 2)
95
- )
96
- / (1 - beta1 ** state["step"])
97
- )
98
- else:
99
- step_size = group["lr"] / (1 - beta1 ** state["step"])
100
- buffered[2] = step_size
101
-
102
- if group["weight_decay"] != 0:
103
- p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
104
-
105
- # more conservative since it's an approximated value
106
- if N_sma >= 5:
107
- denom = exp_avg_sq.sqrt().add_(group["eps"])
108
- p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
109
- else:
110
- p_data_fp32.add_(-step_size, exp_avg)
111
-
112
- p.data.copy_(p_data_fp32)
113
-
114
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
radtts.py CHANGED
@@ -28,8 +28,7 @@ from common import AffineTransformationLayer, LinearNorm, ExponentialClass
28
  from common import get_mask_from_lengths
29
  from attribute_prediction_model import get_attribute_prediction_model
30
  from alignment import mas_width1 as mas
31
-
32
- use_cuda = torch.cuda.is_available()
33
 
34
 
35
  class FlowStep(nn.Module):
@@ -202,10 +201,10 @@ class RADTTS(torch.nn.Module):
202
  if context_lstm_norm is not None:
203
  if "spectral" in context_lstm_norm:
204
  print("Applying spectral norm to context encoder LSTM")
205
- lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
206
  elif "weight" in context_lstm_norm:
207
  print("Applying weight norm to context encoder LSTM")
208
- lstm_norm_fn_pntr = torch.nn.utils.weight_norm
209
 
210
  self.context_lstm = lstm_norm_fn_pntr(
211
  self.context_lstm, "weight_hh_l0"
@@ -688,11 +687,10 @@ class RADTTS(torch.nn.Module):
688
 
689
  if dur is None:
690
  # get token durations
691
- z_dur = torch.empty(batch_size, 1, n_tokens, dtype=torch.float32)
692
- if use_cuda:
693
- z_dur = z_dur.cuda()
694
-
695
- z_dur = z_dur.normal_() * sigma_dur
696
 
697
  dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
698
  if dur.shape[-1] < txt_enc.shape[-1]:
@@ -752,9 +750,7 @@ class RADTTS(torch.nn.Module):
752
  dtype=torch.float32,
753
  )
754
  * sigma_f0
755
- )
756
- if use_cuda:
757
- z_f0 = z_f0.cuda()
758
 
759
  f0 = self.infer_f0(
760
  z_f0,
@@ -780,13 +776,11 @@ class RADTTS(torch.nn.Module):
780
  n_energy_feature_channels,
781
  max_n_frames,
782
  dtype=torch.float32,
 
783
  )
784
  * sigma_energy
785
  )
786
 
787
- if use_cuda:
788
- z_energy_avg = z_energy_avg.cuda()
789
-
790
  energy_avg = self.infer_energy(
791
  z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
792
  )[:, 0]
@@ -829,9 +823,7 @@ class RADTTS(torch.nn.Module):
829
  80 * self.n_group_size,
830
  max_n_frames // self.n_group_size,
831
  dtype=torch.float32,
832
- )
833
- if use_cuda:
834
- residual = residual.cuda()
835
 
836
  residual = residual * sigma
837
 
@@ -921,15 +913,17 @@ class RADTTS(torch.nn.Module):
921
  try:
922
  nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
923
  print("Removed spectral norm from {}".format(name))
924
- except:
925
- pass
 
926
  try:
927
  nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
928
  print("Removed spectral norm from {}".format(name))
929
- except:
930
- pass
 
931
  try:
932
  nn.utils.remove_weight_norm(module)
933
  print("Removed wnorm from {}".format(name))
934
- except:
935
- pass
 
28
  from common import get_mask_from_lengths
29
  from attribute_prediction_model import get_attribute_prediction_model
30
  from alignment import mas_width1 as mas
31
+ from torch_env import device
 
32
 
33
 
34
  class FlowStep(nn.Module):
 
201
  if context_lstm_norm is not None:
202
  if "spectral" in context_lstm_norm:
203
  print("Applying spectral norm to context encoder LSTM")
204
+ lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
205
  elif "weight" in context_lstm_norm:
206
  print("Applying weight norm to context encoder LSTM")
207
+ lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
208
 
209
  self.context_lstm = lstm_norm_fn_pntr(
210
  self.context_lstm, "weight_hh_l0"
 
687
 
688
  if dur is None:
689
  # get token durations
690
+ z_dur = (
691
+ torch.randn(batch_size, 1, n_tokens, dtype=torch.float32, device=device)
692
+ * sigma_dur
693
+ )
 
694
 
695
  dur = self.dur_pred_layer.infer(z_dur, txt_enc, spk_vec_text)
696
  if dur.shape[-1] < txt_enc.shape[-1]:
 
750
  dtype=torch.float32,
751
  )
752
  * sigma_f0
753
+ ).to(device)
 
 
754
 
755
  f0 = self.infer_f0(
756
  z_f0,
 
776
  n_energy_feature_channels,
777
  max_n_frames,
778
  dtype=torch.float32,
779
+ device=device,
780
  )
781
  * sigma_energy
782
  )
783
 
 
 
 
784
  energy_avg = self.infer_energy(
785
  z_energy_avg, ap_txt_enc_time_expanded, spk_vec, out_lens
786
  )[:, 0]
 
823
  80 * self.n_group_size,
824
  max_n_frames // self.n_group_size,
825
  dtype=torch.float32,
826
+ ).to(device)
 
 
827
 
828
  residual = residual * sigma
829
 
 
913
  try:
914
  nn.utils.remove_spectral_norm(module, name="weight_hh_l0")
915
  print("Removed spectral norm from {}".format(name))
916
+ except Exception as e:
917
+ print(e)
918
+
919
  try:
920
  nn.utils.remove_spectral_norm(module, name="weight_hh_l0_reverse")
921
  print("Removed spectral norm from {}".format(name))
922
+ except Exception as e:
923
+ print(e)
924
+
925
  try:
926
  nn.utils.remove_weight_norm(module)
927
  print("Removed wnorm from {}".format(name))
928
+ except Exception as e:
929
+ print(e)
requirements.txt CHANGED
@@ -9,7 +9,4 @@ numba
9
  lmdb
10
  librosa
11
 
12
- unidecode
13
- inflect
14
-
15
  git+https://github.com/langtech-bsc/vocos.git@matcha
 
9
  lmdb
10
  librosa
11
 
 
 
 
12
  git+https://github.com/langtech-bsc/vocos.git@matcha
torch_env.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ seed = 1234
4
+
5
+ # use_mps = torch.mps.is_available()
6
+ use_mps = False
7
+ use_cuda = torch.cuda.is_available()
8
+
9
+ if use_mps:
10
+ device = "mps"
11
+ torch.mps.manual_seed(seed)
12
+ elif use_cuda:
13
+ device = "cuda"
14
+ torch.cuda.manual_seed(seed)
15
+ else:
16
+ device = "cpu"
17
+ torch.manual_seed(seed)
18
+
19
+ print(f"Inference device: {device}")
tts_text_processing/abbreviations.py DELETED
@@ -1,57 +0,0 @@
1
- import re
2
-
3
- _no_period_re = re.compile(r"(No[.])(?=[ ]?[0-9])")
4
- _percent_re = re.compile(r"([ ]?[%])")
5
- _half_re = re.compile("([0-9]½)|(½)")
6
-
7
-
8
- # List of (regular expression, replacement) pairs for abbreviations:
9
- _abbreviations = [
10
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
11
- for x in [
12
- ("mrs", "misess"),
13
- ("ms", "miss"),
14
- ("mr", "mister"),
15
- ("dr", "doctor"),
16
- ("st", "saint"),
17
- ("co", "company"),
18
- ("jr", "junior"),
19
- ("maj", "major"),
20
- ("gen", "general"),
21
- ("drs", "doctors"),
22
- ("rev", "reverend"),
23
- ("lt", "lieutenant"),
24
- ("hon", "honorable"),
25
- ("sgt", "sergeant"),
26
- ("capt", "captain"),
27
- ("esq", "esquire"),
28
- ("ltd", "limited"),
29
- ("col", "colonel"),
30
- ("ft", "fort"),
31
- ]
32
- ]
33
-
34
-
35
- def _expand_no_period(m):
36
- word = m.group(0)
37
- if word[0] == "N":
38
- return "Number"
39
- return "number"
40
-
41
-
42
- def _expand_percent(m):
43
- return " percent"
44
-
45
-
46
- def _expand_half(m):
47
- word = m.group(1)
48
- if word is None:
49
- return "half"
50
- return word[0] + " and a half"
51
-
52
-
53
- def normalize_abbreviations(text):
54
- text = re.sub(_no_period_re, _expand_no_period, text)
55
- text = re.sub(_percent_re, _expand_percent, text)
56
- text = re.sub(_half_re, _expand_half, text)
57
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/acronyms.py DELETED
@@ -1,69 +0,0 @@
1
- import re
2
-
3
- _letter_to_arpabet = {
4
- "A": "EY1",
5
- "B": "B IY1",
6
- "C": "S IY1",
7
- "D": "D IY1",
8
- "E": "IY1",
9
- "F": "EH1 F",
10
- "G": "JH IY1",
11
- "H": "EY1 CH",
12
- "I": "AY1",
13
- "J": "JH EY1",
14
- "K": "K EY1",
15
- "L": "EH1 L",
16
- "M": "EH1 M",
17
- "N": "EH1 N",
18
- "O": "OW1",
19
- "P": "P IY1",
20
- "Q": "K Y UW1",
21
- "R": "AA1 R",
22
- "S": "EH1 S",
23
- "T": "T IY1",
24
- "U": "Y UW1",
25
- "V": "V IY1",
26
- "X": "EH1 K S",
27
- "Y": "W AY1",
28
- "W": "D AH1 B AH0 L Y UW0",
29
- "Z": "Z IY1",
30
- "s": "Z",
31
- }
32
-
33
- # must ignore roman numerals
34
- # _acronym_re = re.compile(r'([A-Z][A-Z]+)s?|([A-Z]\.([A-Z]\.)+s?)')
35
- _acronym_re = re.compile(r"([A-Z][A-Z]+)s?")
36
-
37
-
38
- class AcronymNormalizer(object):
39
- def __init__(self, phoneme_dict):
40
- self.phoneme_dict = phoneme_dict
41
-
42
- def normalize_acronyms(self, text):
43
- def _expand_acronyms(m, add_spaces=True):
44
- acronym = m.group(0)
45
- # remove dots if they exist
46
- acronym = re.sub("\.", "", acronym)
47
-
48
- acronym = "".join(acronym.split())
49
- arpabet = self.phoneme_dict.lookup(acronym)
50
-
51
- if arpabet is None:
52
- acronym = list(acronym)
53
- arpabet = ["{" + _letter_to_arpabet[letter] + "}" for letter in acronym]
54
- # temporary fix
55
- if arpabet[-1] == "{Z}" and len(arpabet) > 1:
56
- arpabet[-2] = arpabet[-2][:-1] + " " + arpabet[-1][1:]
57
- del arpabet[-1]
58
- arpabet = " ".join(arpabet)
59
- elif len(arpabet) == 1:
60
- arpabet = "{" + arpabet[0] + "}"
61
- else:
62
- arpabet = acronym
63
- return arpabet
64
-
65
- text = re.sub(_acronym_re, _expand_acronyms, text)
66
- return text
67
-
68
- def __call__(self, text):
69
- return self.normalize_acronyms(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/cleaners.py CHANGED
@@ -1,26 +1,8 @@
1
  """adapted from https://github.com/keithito/tacotron"""
2
 
3
- """
4
- Cleaners are transformations that run over the input text at both training and eval time.
5
-
6
- Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
- hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
- 1. "english_cleaners" for English text
9
- 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
- the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
- 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
- the symbols in symbols.py to match your data).
13
- """
14
-
15
  import re
16
  from string import punctuation
17
  from functools import reduce
18
- from unidecode import unidecode
19
- from .numerical import normalize_numbers, normalize_currency
20
- from .acronyms import AcronymNormalizer
21
- from .datestime import normalize_datestime
22
- from .letters_and_numbers import normalize_letters_and_numbers
23
- from .abbreviations import normalize_abbreviations
24
 
25
 
26
  # Regular expression matching whitespace:
@@ -30,26 +12,6 @@ _whitespace_re = re.compile(r"\s+")
30
  _arpa_re = re.compile(r"{[^}]+}|\S+")
31
 
32
 
33
- def expand_abbreviations(text):
34
- return normalize_abbreviations(text)
35
-
36
-
37
- def expand_numbers(text):
38
- return normalize_numbers(text)
39
-
40
-
41
- def expand_currency(text):
42
- return normalize_currency(text)
43
-
44
-
45
- def expand_datestime(text):
46
- return normalize_datestime(text)
47
-
48
-
49
- def expand_letters_and_numbers(text):
50
- return normalize_letters_and_numbers(text)
51
-
52
-
53
  def lowercase(text):
54
  return text.lower()
55
 
@@ -58,21 +20,6 @@ def collapse_whitespace(text):
58
  return re.sub(_whitespace_re, " ", text)
59
 
60
 
61
- def separate_acronyms(text):
62
- text = re.sub(r"([0-9]+)([a-zA-Z]+)", r"\1 \2", text)
63
- text = re.sub(r"([a-zA-Z]+)([0-9]+)", r"\1 \2", text)
64
- return text
65
-
66
-
67
- def convert_to_ascii(text):
68
- return unidecode(text)
69
-
70
-
71
- def dehyphenize_compound_words(text):
72
- text = re.sub(r"(?<=[a-zA-Z0-9])-(?=[a-zA-Z])", " ", text)
73
- return text
74
-
75
-
76
  def remove_space_before_punctuation(text):
77
  return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)
78
 
@@ -81,7 +28,6 @@ class Cleaner(object):
81
  def __init__(self, cleaner_names, phonemedict):
82
  self.cleaner_names = cleaner_names
83
  self.phonemedict = phonemedict
84
- self.acronym_normalizer = AcronymNormalizer(self.phonemedict)
85
 
86
  def __call__(self, text):
87
  for cleaner_name in self.cleaner_names:
@@ -94,30 +40,13 @@ class Cleaner(object):
94
  for split in _arpa_re.findall(text)
95
  ]
96
  text = " ".join(text)
 
97
  text = remove_space_before_punctuation(text)
 
98
  return text
99
 
100
  def get_cleaner_fns(self, cleaner_name):
101
- if cleaner_name == "basic_cleaners":
102
- sequence_fns = [lowercase, collapse_whitespace]
103
- word_fns = []
104
- elif cleaner_name == "english_cleaners":
105
- sequence_fns = [collapse_whitespace, convert_to_ascii, lowercase]
106
- word_fns = [expand_numbers, expand_abbreviations]
107
- elif cleaner_name == "radtts_cleaners":
108
- sequence_fns = [
109
- collapse_whitespace,
110
- expand_currency,
111
- expand_datestime,
112
- expand_letters_and_numbers,
113
- ]
114
- word_fns = [expand_numbers, expand_abbreviations]
115
- elif cleaner_name == "ukrainian_cleaners":
116
- sequence_fns = [lowercase, collapse_whitespace]
117
- word_fns = []
118
- elif cleaner_name == "transliteration_cleaners":
119
- sequence_fns = [convert_to_ascii, lowercase, collapse_whitespace]
120
- else:
121
- raise Exception("{} cleaner not supported".format(cleaner_name))
122
 
123
  return sequence_fns, word_fns
 
1
  """adapted from https://github.com/keithito/tacotron"""
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import re
4
  from string import punctuation
5
  from functools import reduce
 
 
 
 
 
 
6
 
7
 
8
  # Regular expression matching whitespace:
 
12
  _arpa_re = re.compile(r"{[^}]+}|\S+")
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def lowercase(text):
16
  return text.lower()
17
 
 
20
  return re.sub(_whitespace_re, " ", text)
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def remove_space_before_punctuation(text):
24
  return re.sub(r"\s([{}](?:\s|$))".format(punctuation), r"\1", text)
25
 
 
28
  def __init__(self, cleaner_names, phonemedict):
29
  self.cleaner_names = cleaner_names
30
  self.phonemedict = phonemedict
 
31
 
32
  def __call__(self, text):
33
  for cleaner_name in self.cleaner_names:
 
40
  for split in _arpa_re.findall(text)
41
  ]
42
  text = " ".join(text)
43
+
44
  text = remove_space_before_punctuation(text)
45
+
46
  return text
47
 
48
  def get_cleaner_fns(self, cleaner_name):
49
+ sequence_fns = [lowercase, collapse_whitespace]
50
+ word_fns = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  return sequence_fns, word_fns
tts_text_processing/cmudict.py DELETED
@@ -1,140 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- import re
4
-
5
-
6
- valid_symbols = [
7
- "AA",
8
- "AA0",
9
- "AA1",
10
- "AA2",
11
- "AE",
12
- "AE0",
13
- "AE1",
14
- "AE2",
15
- "AH",
16
- "AH0",
17
- "AH1",
18
- "AH2",
19
- "AO",
20
- "AO0",
21
- "AO1",
22
- "AO2",
23
- "AW",
24
- "AW0",
25
- "AW1",
26
- "AW2",
27
- "AY",
28
- "AY0",
29
- "AY1",
30
- "AY2",
31
- "B",
32
- "CH",
33
- "D",
34
- "DH",
35
- "EH",
36
- "EH0",
37
- "EH1",
38
- "EH2",
39
- "ER",
40
- "ER0",
41
- "ER1",
42
- "ER2",
43
- "EY",
44
- "EY0",
45
- "EY1",
46
- "EY2",
47
- "F",
48
- "G",
49
- "HH",
50
- "IH",
51
- "IH0",
52
- "IH1",
53
- "IH2",
54
- "IY",
55
- "IY0",
56
- "IY1",
57
- "IY2",
58
- "JH",
59
- "K",
60
- "L",
61
- "M",
62
- "N",
63
- "NG",
64
- "OW",
65
- "OW0",
66
- "OW1",
67
- "OW2",
68
- "OY",
69
- "OY0",
70
- "OY1",
71
- "OY2",
72
- "P",
73
- "R",
74
- "S",
75
- "SH",
76
- "T",
77
- "TH",
78
- "UH",
79
- "UH0",
80
- "UH1",
81
- "UH2",
82
- "UW",
83
- "UW0",
84
- "UW1",
85
- "UW2",
86
- "V",
87
- "W",
88
- "Y",
89
- "Z",
90
- "ZH",
91
- ]
92
-
93
- _valid_symbol_set = set(valid_symbols)
94
-
95
-
96
- class CMUDict:
97
- """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98
-
99
- def __init__(self, file_or_path, keep_ambiguous=True):
100
- if isinstance(file_or_path, str):
101
- with open(file_or_path, encoding="latin-1") as f:
102
- entries = _parse_cmudict(f)
103
- else:
104
- entries = _parse_cmudict(file_or_path)
105
- if not keep_ambiguous:
106
- entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107
- self._entries = entries
108
-
109
- def __len__(self):
110
- return len(self._entries)
111
-
112
- def lookup(self, word):
113
- """Returns list of ARPAbet pronunciations of the given word."""
114
- return self._entries.get(word.upper())
115
-
116
-
117
- _alt_re = re.compile(r"\([0-9]+\)")
118
-
119
-
120
- def _parse_cmudict(file):
121
- cmudict = {}
122
- for line in file:
123
- if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124
- parts = line.split(" ")
125
- word = re.sub(_alt_re, "", parts[0])
126
- pronunciation = _get_pronunciation(parts[1])
127
- if pronunciation:
128
- if word in cmudict:
129
- cmudict[word].append(pronunciation)
130
- else:
131
- cmudict[word] = [pronunciation]
132
- return cmudict
133
-
134
-
135
- def _get_pronunciation(s):
136
- parts = s.strip().split(" ")
137
- for part in parts:
138
- if part not in _valid_symbol_set:
139
- return None
140
- return " ".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/datestime.py DELETED
@@ -1,24 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- import re
4
-
5
- _ampm_re = re.compile(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):?([0-5][0-9])?\s*([AaPp][Mm]\b)")
6
-
7
-
8
- def _expand_ampm(m):
9
- matches = list(m.groups(0))
10
- txt = matches[0]
11
- txt = txt if int(matches[1]) == 0 else txt + " " + matches[1]
12
-
13
- if matches[2][0].lower() == "a":
14
- txt += " a.m."
15
- elif matches[2][0].lower() == "p":
16
- txt += " p.m."
17
-
18
- return txt
19
-
20
-
21
- def normalize_datestime(text):
22
- text = re.sub(_ampm_re, _expand_ampm, text)
23
- # text = re.sub(r"([0-9]|0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])?", r"\1 \2", text)
24
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/grapheme_dictionary.py DELETED
@@ -1,37 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- import re
4
-
5
- _alt_re = re.compile(r"\([0-9]+\)")
6
-
7
-
8
- class Grapheme2PhonemeDictionary:
9
- """Thin wrapper around g2p data."""
10
-
11
- def __init__(self, file_or_path, keep_ambiguous=True, encoding="latin-1"):
12
- with open(file_or_path, encoding=encoding) as f:
13
- entries = _parse_g2p(f)
14
- if not keep_ambiguous:
15
- entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
16
- self._entries = entries
17
-
18
- def __len__(self):
19
- return len(self._entries)
20
-
21
- def lookup(self, word):
22
- """Returns list of pronunciations of the given word."""
23
- return self._entries.get(word.upper())
24
-
25
-
26
- def _parse_g2p(file):
27
- g2p = {}
28
- for line in file:
29
- if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
30
- parts = line.split(" ")
31
- word = re.sub(_alt_re, "", parts[0])
32
- pronunciation = parts[1].strip()
33
- if word in g2p:
34
- g2p[word].append(pronunciation)
35
- else:
36
- g2p[word] = [pronunciation]
37
- return g2p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/heteronyms DELETED
@@ -1,413 +0,0 @@
1
- abject
2
- abrogate
3
- absent
4
- abstract
5
- abuse
6
- ache
7
- acre
8
- acuminate
9
- addict
10
- address
11
- adduct
12
- adele
13
- advocate
14
- affect
15
- affiliate
16
- agape
17
- aged
18
- agglomerate
19
- aggregate
20
- agonic
21
- agora
22
- allied
23
- ally
24
- alternate
25
- alum
26
- am
27
- analyses
28
- andrea
29
- animate
30
- apply
31
- appropriate
32
- approximate
33
- ares
34
- arithmetic
35
- arsenic
36
- articulate
37
- associate
38
- attribute
39
- august
40
- axes
41
- ay
42
- aye
43
- bases
44
- bass
45
- bathed
46
- bested
47
- bifurcate
48
- blessed
49
- blotto
50
- bow
51
- bowed
52
- bowman
53
- brassy
54
- buffet
55
- bustier
56
- carbonate
57
- celtic
58
- choral
59
- chumash
60
- close
61
- closer
62
- coax
63
- coincidence
64
- color coordinate
65
- colour coordinate
66
- comber
67
- combine
68
- combs
69
- committee
70
- commune
71
- compact
72
- complex
73
- compound
74
- compress
75
- concert
76
- conduct
77
- confine
78
- confines
79
- conflict
80
- conglomerate
81
- conscript
82
- conserve
83
- consist
84
- console
85
- consort
86
- construct
87
- consult
88
- consummate
89
- content
90
- contest
91
- contract
92
- contracts
93
- contrast
94
- converse
95
- convert
96
- convict
97
- coop
98
- coordinate
99
- covey
100
- crooked
101
- curate
102
- cussed
103
- decollate
104
- decrease
105
- defect
106
- defense
107
- delegate
108
- deliberate
109
- denier
110
- desert
111
- detail
112
- deviate
113
- diagnoses
114
- diffuse
115
- digest
116
- discard
117
- discharge
118
- discount
119
- do
120
- document
121
- does
122
- dogged
123
- domesticate
124
- dominican
125
- dove
126
- dr
127
- drawer
128
- duplicate
129
- egress
130
- ejaculate
131
- eject
132
- elaborate
133
- ellipses
134
- email
135
- emu
136
- entrace
137
- entrance
138
- escort
139
- estimate
140
- eta
141
- etna
142
- evening
143
- excise
144
- excuse
145
- exploit
146
- export
147
- extract
148
- fine
149
- flower
150
- forbear
151
- four-legged
152
- frequent
153
- furrier
154
- gallant
155
- gel
156
- geminate
157
- gillie
158
- glower
159
- gotham
160
- graduate
161
- haggis
162
- heavy
163
- hinder
164
- house
165
- housewife
166
- impact
167
- imped
168
- implant
169
- implement
170
- import
171
- impress
172
- incense
173
- incline
174
- increase
175
- infix
176
- insert
177
- instar
178
- insult
179
- integral
180
- intercept
181
- interchange
182
- interflow
183
- interleaf
184
- intermediate
185
- intern
186
- interspace
187
- intimate
188
- intrigue
189
- invalid
190
- invert
191
- invite
192
- irony
193
- jagged
194
- jesses
195
- julies
196
- kite
197
- laminate
198
- laos
199
- lather
200
- lead
201
- learned
202
- leasing
203
- lech
204
- legitimate
205
- lied
206
- lima
207
- lipread
208
- live
209
- lower
210
- lunged
211
- maas
212
- magdalen
213
- manes
214
- mare
215
- marked
216
- merchandise
217
- merlion
218
- minute
219
- misconduct
220
- misled
221
- misprint
222
- mobile
223
- moderate
224
- mong
225
- moped
226
- moth
227
- mouth
228
- mow
229
- mpg
230
- multiply
231
- mush
232
- nana
233
- nice
234
- nice
235
- number
236
- numerate
237
- nun
238
- object
239
- opiate
240
- ornament
241
- outbox
242
- outcry
243
- outpour
244
- outreach
245
- outride
246
- outright
247
- outside
248
- outwork
249
- overall
250
- overbid
251
- overcall
252
- overcast
253
- overfall
254
- overflow
255
- overhaul
256
- overhead
257
- overlap
258
- overlay
259
- overuse
260
- overweight
261
- overwork
262
- pace
263
- palled
264
- palling
265
- para
266
- pasty
267
- pate
268
- pauline
269
- pedal
270
- peer
271
- perfect
272
- periodic
273
- permit
274
- pervert
275
- pinta
276
- placer
277
- platy
278
- polish
279
- polish
280
- poll
281
- pontificate
282
- postulate
283
- pram
284
- prayer
285
- precipitate
286
- predate
287
- predicate
288
- prefix
289
- preposition
290
- present
291
- pretest
292
- primer
293
- proceeds
294
- produce
295
- progress
296
- project
297
- proportionate
298
- prospect
299
- protest
300
- pussy
301
- putter
302
- putting
303
- quite
304
- ragged
305
- raven
306
- re
307
- read
308
- reading
309
- reading
310
- real
311
- rebel
312
- recall
313
- recap
314
- recitative
315
- recollect
316
- record
317
- recreate
318
- recreation
319
- redress
320
- refill
321
- refund
322
- refuse
323
- reject
324
- relay
325
- remake
326
- repaint
327
- reprint
328
- reread
329
- rerun
330
- resent
331
- reside
332
- resign
333
- respray
334
- resume
335
- retard
336
- retest
337
- retread
338
- rewrite
339
- root
340
- routed
341
- routing
342
- row
343
- rugged
344
- rummy
345
- sais
346
- sake
347
- sambuca
348
- saucier
349
- second
350
- secrete
351
- secreted
352
- secreting
353
- segment
354
- separate
355
- sewer
356
- shirk
357
- shower
358
- sin
359
- skied
360
- slaver
361
- slough
362
- sow
363
- spoof
364
- squid
365
- stingy
366
- subject
367
- subordinate
368
- subvert
369
- supply
370
- supposed
371
- survey
372
- suspect
373
- syringes
374
- tabulate
375
- tales
376
- tarrier
377
- tarry
378
- taxes
379
- taxis
380
- tear
381
- theron
382
- thou
383
- three-legged
384
- tier
385
- tinged
386
- torment
387
- transfer
388
- transform
389
- transplant
390
- transport
391
- transpose
392
- tush
393
- two-legged
394
- unionised
395
- unionized
396
- update
397
- uplift
398
- upset
399
- use
400
- used
401
- vale
402
- violist
403
- viva
404
- ware
405
- whinged
406
- whoop
407
- wicked
408
- wind
409
- windy
410
- wino
411
- won
412
- worsted
413
- wound
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/letters_and_numbers.py DELETED
@@ -1,96 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- import re
4
-
5
- _letters_and_numbers_re = re.compile(
6
- r"((?:[a-zA-Z]+[0-9]|[0-9]+[a-zA-Z])[a-zA-Z0-9']*)", re.IGNORECASE
7
- )
8
-
9
- _hardware_re = re.compile(
10
- "([0-9]+(?:[.,][0-9]+)?)(?:\s?)(tb|gb|mb|kb|ghz|mhz|khz|hz|mm)", re.IGNORECASE
11
- )
12
- _hardware_key = {
13
- "tb": "terabyte",
14
- "gb": "gigabyte",
15
- "mb": "megabyte",
16
- "kb": "kilobyte",
17
- "ghz": "gigahertz",
18
- "mhz": "megahertz",
19
- "khz": "kilohertz",
20
- "hz": "hertz",
21
- "mm": "millimeter",
22
- "cm": "centimeter",
23
- "km": "kilometer",
24
- }
25
-
26
- _dimension_re = re.compile(
27
- r"\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b|\b(\d+(?:[,.]\d+)?\s*[xX]\s*\d+(?:[,.]\d+)?(?:in|inch|m)?)\b"
28
- )
29
- _dimension_key = {"m": "meter", "in": "inch", "inch": "inch"}
30
-
31
-
32
- def _expand_letters_and_numbers(m):
33
- text = re.split(r"(\d+)", m.group(0))
34
-
35
- # remove trailing space
36
- if text[-1] == "":
37
- text = text[:-1]
38
- elif text[0] == "":
39
- text = text[1:]
40
-
41
- # if not like 1920s, or AK47's , 20th, 1st, 2nd, 3rd, etc...
42
- if text[-1] in ("'s", "s", "th", "nd", "st", "rd") and text[-2].isdigit():
43
- text[-2] = text[-2] + text[-1]
44
- text = text[:-1]
45
-
46
- # for combining digits 2 by 2
47
- new_text = []
48
- for i in range(len(text)):
49
- string = text[i]
50
- if string.isdigit() and len(string) < 5:
51
- # heuristics
52
- if len(string) > 2 and string[-2] == "0":
53
- if string[-1] == "0":
54
- string = [string]
55
- else:
56
- string = [string[:-3], string[-2], string[-1]]
57
- elif len(string) % 2 == 0:
58
- string = [string[i : i + 2] for i in range(0, len(string), 2)]
59
- elif len(string) > 2:
60
- string = [string[0]] + [
61
- string[i : i + 2] for i in range(1, len(string), 2)
62
- ]
63
- new_text.extend(string)
64
- else:
65
- new_text.append(string)
66
-
67
- text = new_text
68
- text = " ".join(text)
69
- return text
70
-
71
-
72
- def _expand_hardware(m):
73
- quantity, measure = m.groups(0)
74
- measure = _hardware_key[measure.lower()]
75
- if measure[-1] != "z" and float(quantity.replace(",", "")) > 1:
76
- return "{} {}s".format(quantity, measure)
77
- return "{} {}".format(quantity, measure)
78
-
79
-
80
- def _expand_dimension(m):
81
- text = "".join([x for x in m.groups(0) if x != 0])
82
- text = text.replace(" x ", " by ")
83
- text = text.replace("x", " by ")
84
- if text.endswith(tuple(_dimension_key.keys())):
85
- if text[-2].isdigit():
86
- text = "{} {}".format(text[:-1], _dimension_key[text[-1:]])
87
- elif text[-3].isdigit():
88
- text = "{} {}".format(text[:-2], _dimension_key[text[-2:]])
89
- return text
90
-
91
-
92
- def normalize_letters_and_numbers(text):
93
- text = re.sub(_hardware_re, _expand_hardware, text)
94
- text = re.sub(_dimension_re, _expand_dimension, text)
95
- text = re.sub(_letters_and_numbers_re, _expand_letters_and_numbers, text)
96
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/numerical.py DELETED
@@ -1,175 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- import inflect
4
- import re
5
-
6
- _magnitudes = ["trillion", "billion", "million", "thousand", "hundred", "m", "b", "t"]
7
- _magnitudes_key = {"m": "million", "b": "billion", "t": "trillion"}
8
- _measurements = "(f|c|k|d|m)"
9
- _measurements_key = {"f": "fahrenheit", "c": "celsius", "k": "thousand", "m": "meters"}
10
- _currency_key = {"$": "dollar", "£": "pound", "€": "euro", "₩": "won"}
11
- _inflect = inflect.engine()
12
- _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
13
- _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
14
- _currency_re = re.compile(
15
- r"([\$€£₩])([0-9\.\,]*[0-9]+)(?:[ ]?({})(?=[^a-zA-Z]))?".format(
16
- "|".join(_magnitudes)
17
- ),
18
- re.IGNORECASE,
19
- )
20
- _measurement_re = re.compile(
21
- r"([0-9\.\,]*[0-9]+(\s)?{}\b)".format(_measurements), re.IGNORECASE
22
- )
23
- _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
24
- # _range_re = re.compile(r'(?<=[0-9])+(-)(?=[0-9])+.*?')
25
- _roman_re = re.compile(
26
- r"\b(?=[MDCLXVI]+\b)M{0,4}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{2,3})\b"
27
- ) # avoid I
28
- _multiply_re = re.compile(r"(\b[0-9]+)(x)([0-9]+)")
29
- _number_re = re.compile(r"[0-9]+'s|[0-9]+s|[0-9]+")
30
-
31
-
32
- def _remove_commas(m):
33
- return m.group(1).replace(",", "")
34
-
35
-
36
- def _expand_decimal_point(m):
37
- return m.group(1).replace(".", " point ")
38
-
39
-
40
- def _expand_currency(m):
41
- currency = _currency_key[m.group(1)]
42
- quantity = m.group(2)
43
- magnitude = m.group(3)
44
-
45
- # remove commas from quantity to be able to convert to numerical
46
- quantity = quantity.replace(",", "")
47
-
48
- # check for million, billion, etc...
49
- if magnitude is not None and magnitude.lower() in _magnitudes:
50
- if len(magnitude) == 1:
51
- magnitude = _magnitudes_key[magnitude.lower()]
52
- return "{} {} {}".format(_expand_hundreds(quantity), magnitude, currency + "s")
53
-
54
- parts = quantity.split(".")
55
- if len(parts) > 2:
56
- return quantity + " " + currency + "s" # Unexpected format
57
-
58
- dollars = int(parts[0]) if parts[0] else 0
59
-
60
- cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
61
- if dollars and cents:
62
- dollar_unit = currency if dollars == 1 else currency + "s"
63
- cent_unit = "cent" if cents == 1 else "cents"
64
- return "{} {}, {} {}".format(
65
- _expand_hundreds(dollars),
66
- dollar_unit,
67
- _inflect.number_to_words(cents),
68
- cent_unit,
69
- )
70
- elif dollars:
71
- dollar_unit = currency if dollars == 1 else currency + "s"
72
- return "{} {}".format(_expand_hundreds(dollars), dollar_unit)
73
- elif cents:
74
- cent_unit = "cent" if cents == 1 else "cents"
75
- return "{} {}".format(_inflect.number_to_words(cents), cent_unit)
76
- else:
77
- return "zero" + " " + currency + "s"
78
-
79
-
80
- def _expand_hundreds(text):
81
- number = float(text)
82
- if number > 1000 < 10000 and (number % 100 == 0) and (number % 1000 != 0):
83
- return _inflect.number_to_words(int(number / 100)) + " hundred"
84
- else:
85
- return _inflect.number_to_words(text)
86
-
87
-
88
- def _expand_ordinal(m):
89
- return _inflect.number_to_words(m.group(0))
90
-
91
-
92
- def _expand_measurement(m):
93
- _, number, measurement = re.split("(\d+(?:\.\d+)?)", m.group(0))
94
- number = _inflect.number_to_words(number)
95
- measurement = "".join(measurement.split())
96
- measurement = _measurements_key[measurement.lower()]
97
- return "{} {}".format(number, measurement)
98
-
99
-
100
- def _expand_range(m):
101
- return " to "
102
-
103
-
104
- def _expand_multiply(m):
105
- left = m.group(1)
106
- right = m.group(3)
107
- return "{} by {}".format(left, right)
108
-
109
-
110
- def _expand_roman(m):
111
- # from https://stackoverflow.com/questions/19308177/converting-roman-numerals-to-integers-in-python
112
- roman_numerals = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
113
- result = 0
114
- num = m.group(0)
115
- for i, c in enumerate(num):
116
- if (i + 1) == len(num) or roman_numerals[c] >= roman_numerals[num[i + 1]]:
117
- result += roman_numerals[c]
118
- else:
119
- result -= roman_numerals[c]
120
- return str(result)
121
-
122
-
123
- def _expand_number(m):
124
- _, number, suffix = re.split(r"(\d+(?:'?\d+)?)", m.group(0))
125
- number = int(number)
126
- if (
127
- number > 1000
128
- and number < 10000
129
- and (number % 100 == 0)
130
- and (number % 1000 != 0)
131
- ):
132
- text = _inflect.number_to_words(number // 100) + " hundred"
133
- elif number > 1000 and number < 3000:
134
- if number == 2000:
135
- text = "two thousand"
136
- elif number > 2000 and number < 2010:
137
- text = "two thousand " + _inflect.number_to_words(number % 100)
138
- elif number % 100 == 0:
139
- text = _inflect.number_to_words(number // 100) + " hundred"
140
- else:
141
- number = _inflect.number_to_words(
142
- number, andword="", zero="oh", group=2
143
- ).replace(", ", " ")
144
- number = re.sub(r"-", " ", number)
145
- text = number
146
- else:
147
- number = _inflect.number_to_words(number, andword="and")
148
- number = re.sub(r"-", " ", number)
149
- number = re.sub(r",", "", number)
150
- text = number
151
-
152
- if suffix in ("'s", "s"):
153
- if text[-1] == "y":
154
- text = text[:-1] + "ies"
155
- else:
156
- text = text + suffix
157
-
158
- return text
159
-
160
-
161
- def normalize_currency(text):
162
- return re.sub(_currency_re, _expand_currency, text)
163
-
164
-
165
- def normalize_numbers(text):
166
- text = re.sub(_comma_number_re, _remove_commas, text)
167
- text = re.sub(_currency_re, _expand_currency, text)
168
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
169
- text = re.sub(_ordinal_re, _expand_ordinal, text)
170
- # text = re.sub(_range_re, _expand_range, text)
171
- # text = re.sub(_measurement_re, _expand_measurement, text)
172
- text = re.sub(_roman_re, _expand_roman, text)
173
- text = re.sub(_multiply_re, _expand_multiply, text)
174
- text = re.sub(_number_re, _expand_number, text)
175
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/symbols.py DELETED
@@ -1,144 +0,0 @@
1
- """adapted from https://github.com/keithito/tacotron"""
2
-
3
- """
4
- Defines the set of symbols used in text input to the model.
5
-
6
- The default is a set of ASCII characters that works well for English or text
7
- that has been run through Unidecode. For other data, you can modify
8
- _characters."""
9
-
10
-
11
- arpabet = [
12
- "AA",
13
- "AA0",
14
- "AA1",
15
- "AA2",
16
- "AE",
17
- "AE0",
18
- "AE1",
19
- "AE2",
20
- "AH",
21
- "AH0",
22
- "AH1",
23
- "AH2",
24
- "AO",
25
- "AO0",
26
- "AO1",
27
- "AO2",
28
- "AW",
29
- "AW0",
30
- "AW1",
31
- "AW2",
32
- "AY",
33
- "AY0",
34
- "AY1",
35
- "AY2",
36
- "B",
37
- "CH",
38
- "D",
39
- "DH",
40
- "EH",
41
- "EH0",
42
- "EH1",
43
- "EH2",
44
- "ER",
45
- "ER0",
46
- "ER1",
47
- "ER2",
48
- "EY",
49
- "EY0",
50
- "EY1",
51
- "EY2",
52
- "F",
53
- "G",
54
- "HH",
55
- "IH",
56
- "IH0",
57
- "IH1",
58
- "IH2",
59
- "IY",
60
- "IY0",
61
- "IY1",
62
- "IY2",
63
- "JH",
64
- "K",
65
- "L",
66
- "M",
67
- "N",
68
- "NG",
69
- "OW",
70
- "OW0",
71
- "OW1",
72
- "OW2",
73
- "OY",
74
- "OY0",
75
- "OY1",
76
- "OY2",
77
- "P",
78
- "R",
79
- "S",
80
- "SH",
81
- "T",
82
- "TH",
83
- "UH",
84
- "UH0",
85
- "UH1",
86
- "UH2",
87
- "UW",
88
- "UW0",
89
- "UW1",
90
- "UW2",
91
- "V",
92
- "W",
93
- "Y",
94
- "Z",
95
- "ZH",
96
- ]
97
-
98
-
99
- def get_symbols(symbol_set):
100
- if symbol_set == "english_basic":
101
- _pad = "_"
102
- _punctuation = "!'\"(),.:;? "
103
- _special = "-"
104
- _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
105
- _arpabet = ["@" + s for s in arpabet]
106
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
107
- elif symbol_set == "english_basic_lowercase":
108
- _pad = "_"
109
- _punctuation = "!'\"(),.:;? "
110
- _special = "-"
111
- _letters = "abcdefghijklmnopqrstuvwxyz"
112
- _arpabet = ["@" + s for s in arpabet]
113
- symbols = list(_pad + _special + _punctuation + _letters) + _arpabet
114
- elif symbol_set == "english_expanded":
115
- _punctuation = "!'\",.:;? "
116
- _math = "#%&*+-/[]()"
117
- _special = "_@©°½—₩€$"
118
- _accented = "áçéêëñöøćž"
119
- _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
120
- _arpabet = ["@" + s for s in arpabet]
121
- symbols = (
122
- list(_punctuation + _math + _special + _accented + _letters) + _arpabet
123
- )
124
- elif symbol_set == "ukrainian":
125
- _punctuation = "'.,?! "
126
- _special = "-+"
127
- _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
128
- symbols = list(_punctuation + _special + _letters)
129
- elif symbol_set == "radtts":
130
- _punctuation = "!'\",.:;? "
131
- _math = "#%&*+-/[]()"
132
- _special = "_@©°½—₩€$"
133
- _accented = "áçéêëñöøćž"
134
- _numbers = "0123456789"
135
- _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
136
- _arpabet = ["@" + s for s in arpabet]
137
- symbols = (
138
- list(_punctuation + _math + _special + _accented + _numbers + _letters)
139
- + _arpabet
140
- )
141
- else:
142
- raise Exception("{} symbol set does not exist".format(symbol_set))
143
-
144
- return symbols
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tts_text_processing/text_processing.py CHANGED
@@ -2,9 +2,8 @@
2
 
3
  import re
4
  import numpy as np
 
5
  from .cleaners import Cleaner
6
- from .symbols import get_symbols
7
- from .grapheme_dictionary import Grapheme2PhonemeDictionary
8
 
9
 
10
  #########
@@ -20,11 +19,14 @@ _words_re = re.compile(
20
  )
21
 
22
 
23
- def lines_to_list(filename):
24
- with open(filename, encoding="utf-8") as f:
25
- lines = f.readlines()
26
- lines = [l.rstrip() for l in lines]
27
- return lines
 
 
 
28
 
29
 
30
  class TextProcessing(object):
@@ -42,18 +44,14 @@ class TextProcessing(object):
42
  add_bos_eos_to_text=False,
43
  encoding="latin-1",
44
  ):
45
- if heteronyms_path is not None and heteronyms_path != "":
46
- self.heteronyms = set(lines_to_list(heteronyms_path))
47
- else:
48
- self.heteronyms = []
49
- # phoneme dict
50
  self.phonemedict = {}
51
 
52
  self.p_phoneme = p_phoneme
53
  self.handle_phoneme = handle_phoneme
54
  self.handle_phoneme_ambiguous = handle_phoneme_ambiguous
55
 
56
- self.symbols = get_symbols(symbol_set)
57
  self.cleaner_names = cleaner_name
58
  self.cleaner = Cleaner(cleaner_name, self.phonemedict)
59
 
 
2
 
3
  import re
4
  import numpy as np
5
+
6
  from .cleaners import Cleaner
 
 
7
 
8
 
9
  #########
 
19
  )
20
 
21
 
22
+ def get_symbols():
23
+ _punctuation = "'.,?! "
24
+ _special = "-+"
25
+ _letters = "абвгґдежзийклмнопрстуфхцчшщьюяєії"
26
+
27
+ symbols = list(_punctuation + _special + _letters)
28
+
29
+ return symbols
30
 
31
 
32
  class TextProcessing(object):
 
44
  add_bos_eos_to_text=False,
45
  encoding="latin-1",
46
  ):
47
+ self.heteronyms = []
 
 
 
 
48
  self.phonemedict = {}
49
 
50
  self.p_phoneme = p_phoneme
51
  self.handle_phoneme = handle_phoneme
52
  self.handle_phoneme_ambiguous = handle_phoneme_ambiguous
53
 
54
+ self.symbols = get_symbols()
55
  self.cleaner_names = cleaner_name
56
  self.cleaner = Cleaner(cleaner_name, self.phonemedict)
57