# Hugging Face Space — running on ZeroGPU ("Zero") hardware.
import importlib | |
from types import SimpleNamespace | |
import gradio as gr | |
import pandas as pd | |
import spaces | |
import torch | |
from utmosv2.utils import get_dataset, get_model | |
# Markdown rendered at the top of the demo page.
description = (
    "# UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)
# Inference device; hard-coded to CUDA for the GPU-backed Space.
device = torch.device("cuda")

# Copy every non-dunder attribute of the stored fusion_stage3 config module
# into a mutable namespace so the app can override individual settings.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{
        name: value
        for name, value in vars(config).items()
        if not name.startswith("__")
    }
)

# Inference-time overrides layered on top of the stored config.
vars(cfg).update(
    reproduce=False,
    config="fusion_stage3",
    print_config=False,
    data_config=None,
    phase="inference",
    num_workers=1,
)
@spaces.GPU
def _predict_mos_gpu(audio_path: str, domain: str, quick: bool) -> float:
    """Run the UTMOSv2 fold ensemble on one audio file and return the mean score.

    Decorated with ``spaces.GPU`` so that a GPU is attached for the duration of
    the call when running on ZeroGPU hardware.

    Full mode averages 5 folds x 5 repetitions (25 predictions). Quick mode
    returns a single prediction from fold 0.
    """
    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    # Placeholder label column; presumably required by get_dataset — TODO confirm.
    data["mos"] = 0

    n_folds, n_repeats = (1, 1) if quick else (5, 5)

    total = 0.0
    for fold in range(n_folds):
        cfg.now_fold = fold
        cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
        model = get_model(cfg, device).eval()
        for _ in range(n_repeats):
            # Rebuilt on every repetition, presumably so that a fresh set of
            # frames is sampled from the waveform — TODO confirm.
            test_dataset = get_dataset(cfg, data, "test")
            # The last element of the sample is dropped (presumably the label).
            inputs = [
                torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
                for t in test_dataset[0][:-1]
            ]
            # Inference only: skip autograd bookkeeping so that converting the
            # output to a Python number cannot fail on a grad-tracking tensor.
            with torch.inference_mode():
                p = model(*inputs)
            total += p[0, 0].item()

    # Average over exactly the number of predictions that were made.
    return total / (n_folds * n_repeats)


def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the MOS of an uploaded audio file.

    Args:
        audio_path: Path to the uploaded file (``None`` when nothing was uploaded).
        domain: Data-domain ID passed to the model as the ``dataset`` column.
        quick: If true, return a single prediction instead of the full
            25-prediction ensemble average.

    Returns:
        The predicted MOS as a Python float.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Validate before entering the GPU-decorated function so that no GPU is
    # requested for an empty submission.
    if not audio_path:
        raise gr.Error("Please upload a .wav file before submitting.")
    return _predict_mos_gpu(audio_path, domain, quick)
# Build the demo UI: audio input, domain selector and quick-mode toggle on the
# left, predicted MOS on the right.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # A single parenthesized string. A trailing comma inside these
                # parentheses would turn it into a one-element tuple, which is
                # not a valid ``info`` value.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "Check this box to make a quick prediction from a single repetition instead."
                ),
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # Event listeners must be registered inside the Blocks context.
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

demo.queue().launch()