Yehor committed on
Commit eac7684 · 1 Parent(s): 1cf3da8
Files changed (8)
  1. app.py +21 -7
  2. audio_processing.py +0 -54
  3. common.py +3 -1
  4. configs/radtts-pp-dap-model.json +0 -39
  5. data.py +3 -365
  6. export_weights.py +3 -3
  7. radtts.py +3 -1
  8. requirements.txt +1 -2
app.py CHANGED
@@ -19,7 +19,7 @@ from huggingface_hub import hf_hub_download
 
 # RAD-TTS code
 from radtts import RADTTS
-from data import Data
+from data import TextProcessor
 from common import update_params
 from torch_env import device
 
@@ -100,10 +100,10 @@ radtts.eval()
 print(f"Loaded checkpoint '{radtts_path}')")
 
 ignore_keys = ["training_files", "validation_files"]
-trainset = Data(
+tp = TextProcessor(
     data_config["training_files"],
     **dict((k, v) for k, v in data_config.items() if k not in ignore_keys),
-)
+).get_processor()
 
 # Config
 concurrency_limit = 5
@@ -186,6 +186,20 @@ examples = [
 ]
 
 
+def get_speaker_id(speaker):
+    speaker_ids = {
+        "lada": 0,
+        "mykyta": 1,
+        "tetiana": 2,
+    }
+
+    return torch.LongTensor([speaker_ids[speaker]])
+
+
+def get_text(text):
+    return torch.LongTensor(tp.encode_text(text))
+
+
 def inference(text, voice):
     if not text:
         raise gr.Error("Please paste your text.")
@@ -209,16 +223,16 @@ def inference(text, voice):
     energy_mean = 0
     energy_std = 0
 
-    tensor_text = trainset.get_text(text).to(device)
+    tensor_text = get_text(text).to(device)
 
-    speaker_id = trainset.get_speaker_id(speaker).to(device)
+    speaker_id = get_speaker_id(speaker).to(device)
     speaker_id_text, speaker_id_attributes = speaker_id, speaker_id
 
     if speaker_text is not None:
-        speaker_id_text = trainset.get_speaker_id(speaker_text).to(device)
+        speaker_id_text = get_speaker_id(speaker_text).to(device)
 
     if speaker_attributes is not None:
-        speaker_id_attributes = trainset.get_speaker_id(speaker_attributes).to(device)
+        speaker_id_attributes = get_speaker_id(speaker_attributes).to(device)
 
     inference_start = time.time()
 
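With the Data dataset gone, app.py builds its model inputs from two module-level helpers plus the shared text processor. A short sketch of what they return, assuming the three-speaker mapping and device handling shown in the hunk above (the sample sentence is arbitrary):

tensor_text = get_text("Привіт, як справи?")  # LongTensor of symbol IDs, shape (N,)
speaker_id = get_speaker_id("tetiana")        # tensor([2])

# As in inference(), both are moved to the target device before synthesis:
tensor_text = tensor_text.to(device)
speaker_id = speaker_id.to(device)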
audio_processing.py CHANGED
@@ -55,7 +55,6 @@ import torch
 import numpy as np
 
 from scipy.signal import get_window
-from librosa.filters import mel as librosa_mel_fn
 import librosa.util as librosa_util
 
 import torch.nn.functional as F
@@ -159,59 +158,6 @@ def dynamic_range_decompression(x, C=1):
     return torch.exp(x) / C
 
 
-class TacotronSTFT(torch.nn.Module):
-    def __init__(
-        self,
-        filter_length=1024,
-        hop_length=256,
-        win_length=1024,
-        n_mel_channels=80,
-        sampling_rate=22050,
-        mel_fmin=0.0,
-        mel_fmax=None,
-    ):
-        super(TacotronSTFT, self).__init__()
-        self.n_mel_channels = n_mel_channels
-        self.sampling_rate = sampling_rate
-        self.stft_fn = STFT(filter_length, hop_length, win_length)
-        mel_basis = librosa_mel_fn(
-            sr=sampling_rate,
-            n_fft=filter_length,
-            n_mels=n_mel_channels,
-            fmin=mel_fmin,
-            fmax=mel_fmax,
-        )
-        mel_basis = torch.from_numpy(mel_basis).float()
-        self.register_buffer("mel_basis", mel_basis)
-
-    def spectral_normalize(self, magnitudes):
-        output = dynamic_range_compression(magnitudes)
-        return output
-
-    def spectral_de_normalize(self, magnitudes):
-        output = dynamic_range_decompression(magnitudes)
-        return output
-
-    def mel_spectrogram(self, y):
-        """Computes mel-spectrograms from a batch of waves
-        PARAMS
-        ------
-        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
-
-        RETURNS
-        -------
-        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
-        """
-        assert torch.min(y.data) >= -1
-        assert torch.max(y.data) <= 1
-
-        magnitudes, phases = self.stft_fn.transform(y)
-        magnitudes = magnitudes.data
-        mel_output = torch.matmul(self.mel_basis, magnitudes)
-        mel_output = self.spectral_normalize(mel_output)
-        return mel_output
-
-
 class STFT(torch.nn.Module):
     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
 
common.py CHANGED
@@ -233,7 +233,9 @@ class ConvLSTMLinear(nn.Module):
                 dilation=1,
                 w_init_gain="relu",
             )
-            conv_layer = torch.nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
+            conv_layer = torch.nn.utils.parametrizations.weight_norm(
+                conv_layer.conv, name="weight"
+            )
             convolutions.append(conv_layer)
 
         self.convolutions = nn.ModuleList(convolutions)
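This hunk is a pure reflow, but it highlights that the convolutions go through the parametrizations API (the old torch.nn.utils.weight_norm is deprecated). A standalone sketch of the same call on a plain Conv1d with made-up channel sizes:

import torch
from torch import nn

conv = nn.Conv1d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
conv = nn.utils.parametrizations.weight_norm(conv, name="weight")

# "weight" is now recomputed from magnitude and direction tensors on each
# access; the hook is visible under conv.parametrizations.weight.
x = torch.randn(1, 256, 100)
print(conv(x).shape)  # torch.Size([1, 256, 100])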
configs/radtts-pp-dap-model.json CHANGED
@@ -1,39 +1,4 @@
 {
-    "train_config": {
-        "output_directory": "outdir_pp_model",
-        "epochs": 10000000,
-        "optim_algo": "RAdam",
-        "learning_rate": 0.001,
-        "weight_decay": 1e-06,
-        "sigma": 1.0,
-        "iters_per_checkpoint": 1000,
-        "batch_size": 16,
-        "seed": null,
-        "checkpoint_path": "",
-        "ignore_layers": [],
-        "ignore_layers_warmstart": [],
-        "finetune_layers": [],
-        "include_layers": [],
-        "vocoder_config_path": "models/hifigan_22khz_config.json",
-        "vocoder_checkpoint_path": "models/hifigan_ljs_generator_v1.pt",
-        "log_attribute_samples": true,
-        "log_decoder_samples": true,
-        "warmstart_checkpoint_path": "outdir_pp/model_100000",
-        "use_amp": true,
-        "grad_clip_val": 1.0,
-        "loss_weights": {
-            "blank_logprob": -1,
-            "ctc_loss_weight": 0.1,
-            "binarization_loss_weight": 1.0,
-            "dur_loss_weight": 1.0,
-            "f0_loss_weight": 1.0,
-            "energy_loss_weight": 1.0,
-            "vpred_loss_weight": 1.0
-        },
-        "binarization_start_iter": 0,
-        "kl_loss_start_iter": 0,
-        "unfreeze_modules": "all"
-    },
     "data_config": {
         "training_files": {
             "LJS": {
@@ -88,10 +53,6 @@
         "distance_tx_unvoiced": false,
         "mel_noise_scale": 0.0
     },
-    "dist_config": {
-        "dist_backend": "nccl",
-        "dist_url": "tcp://localhost:54321"
-    },
     "model_config": {
         "n_speakers": 3,
         "n_speaker_dim": 16,
data.py CHANGED
@@ -38,43 +38,13 @@
 #
 ###############################################################################
 
-import os
-import pickle as pkl
-
-import lmdb
 import torch
 import torch.utils.data
-import numpy as np
-
-from librosa import pyin
-from scipy.io.wavfile import read
-from scipy.stats import betabinom
-from scipy.ndimage import distance_transform_edt as distance_transform
 
-from audio_processing import TacotronSTFT
 from tts_text_processing.text_processing import TextProcessing
 
 
-def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=0.05):
-    P = phoneme_count
-    M = mel_count
-    x = np.arange(0, P)
-    mel_text_probs = []
-    for i in range(1, M + 1):
-        a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
-        rv = betabinom(P - 1, a, b)
-        mel_i_prob = rv.pmf(x)
-        mel_text_probs.append(mel_i_prob)
-    return torch.tensor(np.array(mel_text_probs))
-
-
-def load_wav_to_torch(full_path):
-    """Loads wavdata into torch array"""
-    sampling_rate, data = read(full_path)
-    return torch.from_numpy(np.array(data)).float(), sampling_rate
-
-
-class Data(torch.utils.data.Dataset):
+class TextProcessor(torch.utils.data.Dataset):
     def __init__(
         self,
         datasets,
@@ -114,37 +84,6 @@ class Data(torch.utils.data.Dataset):
         combine_speaker_and_emotion=False,
         **kwargs,
     ):
-        self.combine_speaker_and_emotion = combine_speaker_and_emotion
-        self.max_wav_value = max_wav_value
-        self.audio_lmdb_dict = {}  # dictionary of lmdbs for audio data
-        self.data = self.load_data(datasets)
-        self.distance_tx_unvoiced = False
-        if "distance_tx_unvoiced" in kwargs.keys():
-            self.distance_tx_unvoiced = kwargs["distance_tx_unvoiced"]
-        self.stft = TacotronSTFT(
-            filter_length=filter_length,
-            hop_length=hop_length,
-            win_length=win_length,
-            sampling_rate=sampling_rate,
-            n_mel_channels=n_mel_channels,
-            mel_fmin=mel_fmin,
-            mel_fmax=mel_fmax,
-        )
-
-        self.do_mel_scaling = kwargs.get("do_mel_scaling", True)
-        self.mel_noise_scale = kwargs.get("mel_noise_scale", 0.0)
-        self.filter_length = filter_length
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.mel_fmin = mel_fmin
-        self.mel_fmax = mel_fmax
-        self.f0_min = f0_min
-        self.f0_max = f0_max
-        self.use_f0 = use_f0
-        self.use_log_f0 = use_log_f0
-        self.use_energy_avg = use_energy_avg
-        self.use_scaled_energy = use_scaled_energy
-        self.sampling_rate = sampling_rate
         self.tp = TextProcessing(
             symbol_set,
             cleaner_names,
@@ -158,306 +97,5 @@ class Data(torch.utils.data.Dataset):
             add_bos_eos_to_text=add_bos_eos_to_text,
         )
 
-        self.dur_min = dur_min
-        self.dur_max = dur_max
-        if speaker_ids is None or speaker_ids == "":
-            self.speaker_ids = self.create_speaker_lookup_table(self.data)
-        else:
-            self.speaker_ids = speaker_ids
-
-        print("Number of files", len(self.data))
-        if include_speakers is not None:
-            for speaker_set, include in include_speakers:
-                self.filter_by_speakers_(speaker_set, include)
-            print("Number of files after speaker filtering", len(self.data))
-
-        if dur_min is not None and dur_max is not None:
-            self.filter_by_duration_(dur_min, dur_max)
-            print("Number of files after duration filtering", len(self.data))
-
-        self.use_attn_prior_masking = bool(use_attn_prior_masking)
-        self.prepend_space_to_text = bool(prepend_space_to_text)
-        self.append_space_to_text = bool(append_space_to_text)
-        self.betabinom_cache_path = betabinom_cache_path
-        self.betabinom_scaling_factor = betabinom_scaling_factor
-        self.lmdb_cache_path = lmdb_cache_path
-        if self.lmdb_cache_path != "":
-            self.cache_data_lmdb = lmdb.open(
-                self.lmdb_cache_path, readonly=True, max_readers=1024, lock=False
-            ).begin()
-
-        # # make sure caching path exists
-        # if not os.path.exists(self.betabinom_cache_path):
-        # os.makedirs(self.betabinom_cache_path)
-
-        print("Dataloader initialized with no augmentations")
-        self.speaker_map = None
-        if "speaker_map" in kwargs:
-            self.speaker_map = kwargs["speaker_map"]
-
-    def load_data(self, datasets, split="|"):
-        dataset = []
-        for dset_name, dset_dict in datasets.items():
-            folder_path = dset_dict["basedir"]
-            audiodir = dset_dict["audiodir"]
-            filename = dset_dict["filelist"]
-            audio_lmdb_key = None
-            if "lmdbpath" in dset_dict.keys() and len(dset_dict["lmdbpath"]) > 0:
-                self.audio_lmdb_dict[dset_name] = lmdb.open(
-                    dset_dict["lmdbpath"], readonly=True, max_readers=256, lock=False
-                ).begin()
-                audio_lmdb_key = dset_name
-
-            wav_folder_prefix = os.path.join(folder_path, audiodir)
-            filelist_path = os.path.join(folder_path, filename)
-            with open(filelist_path, encoding="utf-8") as f:
-                data = [line.strip().split(split) for line in f]
-
-            for d in data:
-                emotion = "other" if len(d) == 3 else d[3]
-                duration = -1 if len(d) == 3 else d[4]
-                dataset.append(
-                    {
-                        "audiopath": os.path.join(wav_folder_prefix, d[0]),
-                        "text": d[1],
-                        "speaker": d[2] + "-" + emotion
-                        if self.combine_speaker_and_emotion
-                        else d[2],
-                        "emotion": emotion,
-                        "duration": float(duration),
-                        "lmdb_key": audio_lmdb_key,
-                    }
-                )
-        return dataset
-
-    def filter_by_speakers_(self, speakers, include=True):
-        print("Include spaker {}: {}".format(speakers, include))
-        if include:
-            self.data = [x for x in self.data if x["speaker"] in speakers]
-        else:
-            self.data = [x for x in self.data if x["speaker"] not in speakers]
-
-    def filter_by_duration_(self, dur_min, dur_max):
-        self.data = [
-            x
-            for x in self.data
-            if x["duration"] == -1
-            or (x["duration"] >= dur_min and x["duration"] <= dur_max)
-        ]
-
-    def create_speaker_lookup_table(self, data):
-        speaker_ids = np.sort(np.unique([x["speaker"] for x in data]))
-        d = {speaker_ids[i]: i for i in range(len(speaker_ids))}
-        print("Number of speakers:", len(d))
-        print("Speaker IDS", d)
-        return d
-
-    def f0_normalize(self, x):
-        if self.use_log_f0:
-            mask = x >= self.f0_min
-            x[mask] = torch.log(x[mask])
-            x[~mask] = 0.0
-
-        return x
-
-    def f0_denormalize(self, x):
-        if self.use_log_f0:
-            log_f0_min = np.log(self.f0_min)
-            mask = x >= log_f0_min
-            x[mask] = torch.exp(x[mask])
-            x[~mask] = 0.0
-        x[x <= 0.0] = 0.0
-
-        return x
-
-    def energy_avg_normalize(self, x):
-        if self.use_scaled_energy:
-            x = (x + 20.0) / 20.0
-        return x
-
-    def energy_avg_denormalize(self, x):
-        if self.use_scaled_energy:
-            x = x * 20.0 - 20.0
-        return x
-
-    def get_f0_pvoiced(
-        self,
-        audio,
-        sampling_rate=22050,
-        frame_length=1024,
-        hop_length=256,
-        f0_min=100,
-        f0_max=300,
-    ):
-        audio_norm = audio / self.max_wav_value
-        f0, voiced_mask, p_voiced = pyin(
-            audio_norm,
-            f0_min,
-            f0_max,
-            sampling_rate,
-            frame_length=frame_length,
-            win_length=frame_length // 2,
-            hop_length=hop_length,
-        )
-        f0[~voiced_mask] = 0.0
-        f0 = torch.FloatTensor(f0)
-        p_voiced = torch.FloatTensor(p_voiced)
-        voiced_mask = torch.FloatTensor(voiced_mask)
-        return f0, voiced_mask, p_voiced
-
-    def get_energy_average(self, mel):
-        energy_avg = mel.mean(0)
-        energy_avg = self.energy_avg_normalize(energy_avg)
-        return energy_avg
-
-    def get_mel(self, audio):
-        audio_norm = audio / self.max_wav_value
-        audio_norm = audio_norm.unsqueeze(0)
-        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
-        melspec = self.stft.mel_spectrogram(audio_norm)
-        melspec = torch.squeeze(melspec, 0)
-        if self.do_mel_scaling:
-            melspec = (melspec + 5.5) / 2
-        if self.mel_noise_scale > 0:
-            melspec += torch.randn_like(melspec) * self.mel_noise_scale
-        return melspec
-
-    def get_speaker_id(self, speaker):
-        if self.speaker_map is not None and speaker in self.speaker_map:
-            speaker = self.speaker_map[speaker]
-
-        return torch.LongTensor([self.speaker_ids[speaker]])
-
-    def get_text(self, text):
-        text = self.tp.encode_text(text)
-        text = torch.LongTensor(text)
-        return text
-
-    def get_attention_prior(self, n_tokens, n_frames):
-        # cache the entire attn_prior by filename
-        if self.use_attn_prior_masking:
-            filename = "{}_{}".format(n_tokens, n_frames)
-            prior_path = os.path.join(self.betabinom_cache_path, filename)
-            prior_path += "_prior.pth"
-            if self.lmdb_cache_path != "":
-                attn_prior = pkl.loads(
-                    self.cache_data_lmdb.get(prior_path.encode("ascii"))
-                )
-            elif os.path.exists(prior_path):
-                attn_prior = torch.load(prior_path)
-            else:
-                attn_prior = beta_binomial_prior_distribution(
-                    n_tokens, n_frames, self.betabinom_scaling_factor
-                )
-                torch.save(attn_prior, prior_path)
-        else:
-            attn_prior = torch.ones(n_frames, n_tokens)  # all ones baseline
-
-        return attn_prior
-
-    def __getitem__(self, index):
-        data = self.data[index]
-        audiopath, text = data["audiopath"], data["text"]
-        speaker_id = data["speaker"]
-
-        if data["lmdb_key"] is not None:
-            data_dict = pkl.loads(
-                self.audio_lmdb_dict[data["lmdb_key"]].get(audiopath.encode("ascii"))
-            )
-            audio = data_dict["audio"]
-            sampling_rate = data_dict["sampling_rate"]
-        else:
-            audio, sampling_rate = load_wav_to_torch(audiopath)
-
-        if sampling_rate != self.sampling_rate:
-            raise ValueError(
-                "{} SR doesn't match target {} SR".format(
-                    sampling_rate, self.sampling_rate
-                )
-            )
-
-        mel = self.get_mel(audio)
-        f0 = None
-        p_voiced = None
-        voiced_mask = None
-        if self.use_f0:
-            filename = "_".join(audiopath.split("/")[-3:])
-            f0_path = os.path.join(self.betabinom_cache_path, filename)
-            f0_path += "_f0_sr{}_fl{}_hl{}_f0min{}_f0max{}_log{}.pt".format(
-                self.sampling_rate,
-                self.filter_length,
-                self.hop_length,
-                self.f0_min,
-                self.f0_max,
-                self.use_log_f0,
-            )
-
-            dikt = None
-            if len(self.lmdb_cache_path) > 0:
-                dikt = pkl.loads(self.cache_data_lmdb.get(f0_path.encode("ascii")))
-                f0 = dikt["f0"]
-                p_voiced = dikt["p_voiced"]
-                voiced_mask = dikt["voiced_mask"]
-            elif os.path.exists(f0_path):
-                try:
-                    dikt = torch.load(f0_path)
-                except Exception as e:
-                    print(e)
-                    print(f"f0 loading from {f0_path} is broken, recomputing.")
-
-            if dikt is not None:
-                f0 = dikt["f0"]
-                p_voiced = dikt["p_voiced"]
-                voiced_mask = dikt["voiced_mask"]
-            else:
-                f0, voiced_mask, p_voiced = self.get_f0_pvoiced(
-                    audio.cpu().numpy(),
-                    self.sampling_rate,
-                    self.filter_length,
-                    self.hop_length,
-                    self.f0_min,
-                    self.f0_max,
-                )
-                print("saving f0 to {}".format(f0_path))
-                torch.save(
-                    {"f0": f0, "voiced_mask": voiced_mask, "p_voiced": p_voiced},
-                    f0_path,
-                )
-            if f0 is None:
-                raise Exception("STOP, BROKEN F0 {}".format(audiopath))
-
-            f0 = self.f0_normalize(f0)
-            if self.distance_tx_unvoiced:
-                mask = f0 <= 0.0
-                distance_map = np.log(distance_transform(mask))
-                distance_map[distance_map <= 0] = 0.0
-                f0 = f0 - distance_map
-
-        energy_avg = None
-        if self.use_energy_avg:
-            energy_avg = self.get_energy_average(mel)
-            if self.use_scaled_energy and energy_avg.min() < 0.0:
-                print(audiopath, "has scaled energy avg smaller than 0")
-
-        speaker_id = self.get_speaker_id(speaker_id)
-        text_encoded = self.get_text(text)
-
-        attn_prior = self.get_attention_prior(text_encoded.shape[0], mel.shape[1])
-
-        if not self.use_attn_prior_masking:
-            attn_prior = None
-
-        return {
-            "mel": mel,
-            "speaker_id": speaker_id,
-            "text_encoded": text_encoded,
-            "audiopath": audiopath,
-            "attn_prior": attn_prior,
-            "f0": f0,
-            "p_voiced": p_voiced,
-            "voiced_mask": voiced_mask,
-            "energy_avg": energy_avg,
-        }
-
-    def __len__(self):
-        return len(self.data)
+    def get_processor(self):
+        return self.tp
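Data shrinks to a thin TextProcessor wrapper: it keeps the old constructor signature so the data_config keyword arguments still apply, but it only builds the TextProcessing object and exposes it via get_processor(); all audio, F0, LMDB, and dataset-iteration code is removed. A usage sketch mirroring app.py (data_config is the dict from the JSON config; the sample sentence is arbitrary):

from data import TextProcessor

ignore_keys = ["training_files", "validation_files"]
tp = TextProcessor(
    data_config["training_files"],
    **{k: v for k, v in data_config.items() if k not in ignore_keys},
).get_processor()

token_ids = tp.encode_text("Садок вишневий коло хати.")  # list of symbol IDs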
export_weights.py CHANGED
@@ -5,9 +5,9 @@ radtts_path_state = "models/radtts-pp-dap-model/model_dap_84000_state.pt"
 
 checkpoint_dict = torch.load(radtts_path, map_location="cpu")
 
-del checkpoint_dict['iteration']
-del checkpoint_dict['optimizer']
-del checkpoint_dict['learning_rate']
+del checkpoint_dict["iteration"]
+del checkpoint_dict["optimizer"]
+del checkpoint_dict["learning_rate"]
 
 print(checkpoint_dict.keys())
 
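Only the quoting style changes here, but for context the script's job is to strip training-only entries from a checkpoint before re-saving just the weights. A hedged sketch of the overall flow (the save step and the source path are assumptions; only radtts_path_state appears in this diff):

import torch

radtts_path = "models/radtts-pp-dap-model/model_dap_84000.pt"  # assumed source file
radtts_path_state = "models/radtts-pp-dap-model/model_dap_84000_state.pt"

checkpoint_dict = torch.load(radtts_path, map_location="cpu")

# Drop training-only state so the exported file carries just the weights.
for key in ("iteration", "optimizer", "learning_rate"):
    checkpoint_dict.pop(key, None)

torch.save(checkpoint_dict, radtts_path_state)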
radtts.py CHANGED
@@ -201,7 +201,9 @@ class RADTTS(torch.nn.Module):
             if context_lstm_norm is not None:
                 if "spectral" in context_lstm_norm:
                     print("Applying spectral norm to context encoder LSTM")
-                    lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
+                    lstm_norm_fn_pntr = (
+                        torch.nn.utils.parametrizations.spectral_norm
+                    )
                 elif "weight" in context_lstm_norm:
                     print("Applying weight norm to context encoder LSTM")
                     lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
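Same parametrizations-API point as in common.py, here for spectral norm. The function pointer is applied to the context LSTM further down; that call site is not in this hunk, but it would look roughly like the following sketch (layer sizes are made up):

import torch
from torch import nn

lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm

context_lstm = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
# Normalize the recurrent weight matrix; other weights can be wrapped the same way.
context_lstm = lstm_norm_fn_pntr(context_lstm, name="weight_hh_l0")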
requirements.txt CHANGED
@@ -1,12 +1,11 @@
 huggingface_hub
 
-gradio==5.18.0
+gradio
 
 torch
 torchaudio
 scipy
 numba
-lmdb
 librosa
 
 git+https://github.com/langtech-bsc/vocos.git@matcha