import os
import json
import math

import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
import gradio as gr
from transformers import WavLMModel

from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from models import Generator
from Utils.JDC.model import JDCNet


# ---------------------------------------------------------------------------
# Input files (relative to the current working directory)
# ---------------------------------------------------------------------------
hpfile = "config_v1_16k.json"                # model hyper-parameters (JSON)
ptfile = "exp/default/g_00700000"            # generator checkpoint
spk2id_path = "filelists/spk2id.json"        # speaker name -> integer id
f0_stats_path = "filelists/f0_stats.json"    # per-speaker F0 statistics
spk_stats_path = "filelists/spk_stats.json"  # per-speaker reference choice
spk_emb_dir = "dataset/spk"                  # <spk>/<utt>.npy embeddings
spk_wav_dir = "dataset/audio"                # <spk>/<utt>.wav recordings

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the model hyper-parameters. `h` gives attribute access to the JSON
# keys (e.g. h.n_fft, h.sampling_rate, as used in convert()).
# JSON text is UTF-8; be explicit rather than relying on the locale default.
with open(hpfile, encoding="utf-8") as f:
    json_config = json.load(f)
h = AttrDict(json_config)

# --- models ----------------------------------------------------------------
# The JDCNet F0 network is handed to the Generator constructor.
F0_model = JDCNet(num_class=1, seq_len=192)
generator = Generator(h, F0_model).to(device)

# Restore trained weights (strict: every key must match), drop weight norm for
# inference, and switch to eval mode.
state_dict_g = torch.load(ptfile, map_location=device)
generator.load_state_dict(state_dict_g['generator'], strict=True)
generator.remove_weight_norm()
generator.eval()

# Pretrained WavLM base-plus; convert() uses its last hidden state as input
# features for the generator.
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(device).eval()

# --- per-speaker metadata ----------------------------------------------------
def _load_json(path):
    """Read and parse the UTF-8 JSON file at *path*."""
    # Explicit encoding: the default is locale-dependent, and speaker names
    # used as keys here may contain non-ASCII characters.
    with open(path, encoding="utf-8") as fp:
        return json.load(fp)


spk2id = _load_json(spk2id_path)        # speaker name -> integer id
f0_stats = _load_json(f0_stats_path)    # speaker -> {"mean": ..., ...}
spk_stats = _load_json(spk_stats_path)  # speaker -> {"best_spk_emb": ..., ...}

# --- pitch shifting ------------------------------------------------------------
# Frames whose F0 is not above this value are treated as unvoiced.
threshold = 10
# One shift step: the log range [log 50, log 1100] divided into 256 equal parts
# (a factor of ~1.0121, i.e. ~0.21 semitone, per step).
step = (math.log(1100) - math.log(50)) / 256


def tune_f0(initial_f0, i):
    """Shift the voiced frames of an F0 tensor by ``i`` log-frequency steps.

    Args:
        initial_f0: Tensor of F0 values (presumably Hz, judging by the
            50-1100 range that defines ``step``).
        i: Number of steps; positive raises pitch, negative lowers it.

    Returns:
        ``initial_f0`` itself when ``i == 0``; otherwise a new tensor where
        every value above ``threshold`` is multiplied by ``exp(step * i)``
        and all other values are copied unchanged.
    """
    if i == 0:
        return initial_f0
    is_voiced = initial_f0 > threshold
    # Unvoiced zeros give log(0) = -inf here, but torch.where discards them.
    shifted = torch.exp(torch.log(initial_f0) + step * i)
    return torch.where(is_voiced, shifted, initial_f0)

# convert function
def convert(tgt_spk, src_wav, f0_shift=0):
    """Convert ``src_wav`` into the voice of ``tgt_spk`` and write it to disk.

    Args:
        tgt_spk: Target speaker name; looked up in ``spk2id``, ``spk_stats``
            and ``f0_stats``.
        src_wav: Path to the source recording (the UI's Audio components use
            ``type='filepath'``).
        f0_shift: Pitch shift passed to ``tune_f0`` (in its log-F0 steps);
            0 leaves the predicted F0 untouched.

    Returns:
        The path ``"out.wav"`` of the written 16-bit PCM file.

    Raises:
        gr.Error: If no target speaker or no source audio was supplied, so
            the UI shows a readable message instead of a stack trace.
    """
    if not tgt_spk:
        raise gr.Error("Please select a target speaker.")
    if not src_wav:
        raise gr.Error("Please provide source audio.")

    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_emb = f"{spk_emb_dir}/{tgt_spk}/{tgt_ref}.npy"
    out_wav = "out.wav"

    with torch.no_grad():
        # Target-speaker conditioning: id (1, 1), stored embedding, mean F0 (1, 1).
        spk_id = torch.LongTensor([spk2id[tgt_spk]]).unsqueeze(0).to(device)

        spk_emb = np.load(tgt_emb)
        spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)

        f0_mean_tgt = f0_stats[tgt_spk]["mean"]
        f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)

        # Source: load resampled to 16 kHz, then compute mel and WavLM features.
        wav, _ = librosa.load(src_wav, sr=16000)
        wav = torch.FloatTensor(wav).to(device)
        mel = mel_spectrogram(wav.unsqueeze(0), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)

        x = wavlm(wav.unsqueeze(0)).last_hidden_state
        x = x.transpose(1, 2)  # (B, C, T)
        # Right-pad x along time so it has as many frames as mel.
        x = F.pad(x, (0, mel.size(2) - x.size(2)), 'constant')

        # Predict F0 for the target, apply the user's shift, then synthesize.
        f0 = generator.get_f0(mel, f0_mean_tgt)
        f0 = tune_f0(f0, f0_shift)
        x = generator.get_x(x, spk_emb, spk_id)
        y = generator.infer(x, f0)

        audio = y.squeeze()
        # Peak-normalise to 0.95 of full scale. Skip it for an all-zero output:
        # 0 / 0 would give NaN, which the int16 cast turns into garbage.
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            audio = audio / peak * 0.95
        audio = audio * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')

    sf.write(out_wav, audio, h.sampling_rate, "PCM_16")
    return out_wav

# Reference-audio lookup used by the speaker dropdown.
def change_spk(tgt_spk):
    """Return the path of the reference recording for ``tgt_spk``.

    The path is ``<spk_wav_dir>/<tgt_spk>/<best_spk_emb>.wav``, where
    ``best_spk_emb`` is read from ``spk_stats[tgt_spk]``.
    """
    utt = spk_stats[tgt_spk]["best_spk_emb"]
    return f"{spk_wav_dir}/{tgt_spk}/{utt}.wav"

# interface
with gr.Blocks() as demo:
    gr.Markdown("# PitchVC")
    gr.Markdown("Gradio Demo for PitchVC. ([Github Repo](https://github.com/OlaWod/PitchVC))")

    with gr.Row():
        with gr.Column():
            # Dropdown `choices` is typed as a list; a dict_keys view may not
            # serialize into the UI config, so materialize it.
            tgt_spk = gr.Dropdown(choices=list(spk2id.keys()), type="value", label="Target Speaker")
            # Populated by change_spk whenever the target speaker changes.
            ref_audio = gr.Audio(label="Reference Audio", type='filepath')
            src_audio = gr.Audio(label="Source Audio", type='filepath')
            # Passed to convert() as f0_shift (tune_f0 steps).
            f0_shift = gr.Slider(minimum=-30, maximum=30, value=0, step=1, label="F0 Shift")
        with gr.Column():
            out_audio = gr.Audio(label="Output Audio", type='filepath')
            submit = gr.Button(value="Submit")

    tgt_spk.change(fn=change_spk, inputs=[tgt_spk], outputs=[ref_audio])
    submit.click(convert, [tgt_spk, src_audio, f0_shift], [out_audio])

    examples = gr.Examples(
        examples=[
            ["p225", 'dataset/audio/p226/p226_341.wav', 0],
            ["p226", 'dataset/audio/p225/p225_220.wav', -5],
        ],
        inputs=[tgt_spk, src_audio, f0_shift],
    )

demo.launch()