# ITS-C4SF733\Administrator
# all resource
# 324bf29
import os
import torch
import logging
from macros import *
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from models.vallex import VALLE
from vocos import Vocos
def get_model(device):
    """Build the VALL-E X model, audio tokenizer and vocoder on ``device``.

    If the checkpoint file is not present in ``./checkpoints`` it is
    downloaded first. The weights are then loaded into a freshly constructed
    ``VALLE`` instance.

    Args:
        device: Passed to ``.to(device)`` for the model and the vocoder, and
            to the ``AudioTokenizer`` constructor.

    Returns:
        A ``(model, codec, vocos)`` tuple: the ``VALLE`` model with the
        checkpoint weights loaded, the ``AudioTokenizer``, and the
        ``Vocos`` instance loaded from ``charactr/vocos-encodec-24khz``.

    Raises:
        RuntimeError: If the checkpoint is missing and the download fails
            (the original error is chained as ``__cause__``), or if
            ``load_state_dict`` finds missing/unexpected keys.
    """
    url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
    checkpoints_dir = "./checkpoints"
    model_checkpoint_name = "vallex-checkpoint_modified.pt"
    # Single source of truth for the checkpoint location: the existence
    # check, the download target and the load below must all agree, otherwise
    # a fresh download is never found by torch.load.
    checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)

    os.makedirs(checkpoints_dir, exist_ok=True)

    if not os.path.exists(checkpoint_path):
        # Imported lazily so wget is only required when a download is needed.
        import wget

        logging.info("Downloading model from %s ...", url)
        try:
            wget.download(url, out=checkpoint_path, bar=wget.bar_adaptive)
        except Exception as e:
            logging.error("Model weights download failed: %s", e)
            raise RuntimeError(
                "\n Model weights download failed, please go to '{}'"
                "\n manually download model weights and put it to {} .".format(
                    url, os.path.abspath(checkpoint_path)
                )
            ) from e

    # VALL-E
    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)

    # NOTE(review): torch.load unpickles the file, and this checkpoint may
    # have been fetched from a remote URL. Consider weights_only=True if the
    # checkpoint format allows it.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=True makes load_state_dict raise on any missing or unexpected
    # key, so the returned key lists need no separate check.
    model.load_state_dict(checkpoint["model"], strict=True)

    # Encodec
    codec = AudioTokenizer(device)
    vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)

    return model, codec, vocos