UTMOSv2 / app.py
Wataru's picture
updated dependency
95a2b02
import importlib
from types import SimpleNamespace
import gradio as gr
import pandas as pd
import spaces
import torch
from utmosv2.utils import get_dataset, get_model
# Markdown rendered at the top of the demo page.
# The title emoji is spelled as a named escape so that it cannot be corrupted
# by an encoding round-trip of this file.
description = (
    "# \N{ROCKET} UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)
# Inference device; fixed to CUDA.
device = torch.device("cuda")

# Load the "fusion_stage3" configuration module and copy every attribute whose
# name does not start with a double underscore into a mutable namespace.
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{name: value for name, value in vars(config).items() if not name.startswith("__")}
)

# Settings applied on top of the loaded configuration for this demo.
vars(cfg).update(
    reproduce=False,
    config="fusion_stage3",
    print_config=False,
    data_config=None,
    phase="inference",
    num_workers=1,
)
@spaces.GPU
@torch.inference_mode()
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
    """Predict the MOS of one audio file.

    Args:
        audio_path: Path of the audio file to score.
        domain: Data-domain ID; stored in the ``dataset`` column of the
            one-row frame handed to ``get_dataset``.
        quick: If true, make a single prediction with the fold-0 model only.
            Otherwise make five predictions with each of the five fold models
            and average all 25 of them.

    Returns:
        The single prediction in quick mode, otherwise the mean prediction.
        The value is accumulated from NumPy scalars, so in practice it is a
        NumPy floating scalar rather than a built-in ``float``.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Gradio passes None when the audio component is empty; fail with a
    # readable message instead of an opaque error from the data loader.
    if not audio_path:
        raise gr.Error("Please provide an audio file before submitting.")

    # Quick mode stops after the very first prediction: one fold, one repetition.
    num_folds = 1 if quick else 5
    num_repetitions = 1 if quick else 5

    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0  # placeholder value for the "mos" column

    preds = 0.0
    for fold in range(num_folds):
        # get_model receives the fold index and checkpoint path through the
        # shared module-level cfg.
        cfg.now_fold = fold
        cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
        model = get_model(cfg, device).eval()
        for _ in range(num_repetitions):
            # The dataset is rebuilt for every repetition, as in the original flow.
            test_dataset = get_dataset(cfg, data, "test")
            # Every element of the sample except the last one is fed to the
            # model (the last is presumably the target -- TODO confirm), each
            # as a float32 tensor with a leading dimension of size 1.
            inputs = [
                torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
                for t in test_dataset[0][:-1]
            ]
            preds += model(*inputs).cpu().numpy()[0][0]

    # Dividing by 1 in quick mode leaves the single prediction unchanged.
    return preds / (num_folds * num_repetitions)
# Page layout: one row with two columns, the inputs plus submit button in the
# first and the prediction text box in the second.
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                # Parentheses are used for implicit string concatenation only.
                # No trailing comma: that would make this a one-element tuple
                # instead of the string the component expects.
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "To make quick predictions by reducing this to a single repetition, "
                    "check this checkbox:"
                ),
            )
            submit = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    # The event listener is registered inside the Blocks context.
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

# Enable the request queue and start the app.
demo.queue().launch()