# llama-kokoro / app.py
# import gradio as gr
import os
import torch
from huggingface_hub import InferenceClient
# FastAPI server dependencies
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import uvicorn
#################
# Import eSpeak TTS pipeline
from tts_cli import (
    build_model as build_model_espeak,
    generate_long_form_tts as generate_long_form_tts_espeak,
)
# Import OpenPhonemizer TTS pipeline
from tts_cli_op import (
    build_model as build_model_open,
    generate_long_form_tts as generate_long_form_tts_open,
)
# ---------------------------------------------------------------------
# Path to models and voicepacks
# ---------------------------------------------------------------------
MODELS_DIR = "pretrained_models/Kokoro"
VOICES_DIR = "pretrained_models/Kokoro/voices"
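# HF_TOKEN must be set in the environment (e.g. `export HF_TOKEN=hf_...`);
# the InferenceClient below uses it to call the hosted LLAMA model, and
# chat requests will fail with an auth error if it is missing.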
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=HF_TOKEN)
# ---------------------------------------------------------------------
# List the models (.pth) and voices (.pt)
# ---------------------------------------------------------------------
def get_models():
    return sorted([f for f in os.listdir(MODELS_DIR) if f.endswith(".pth")])

def get_voices():
    return sorted([f for f in os.listdir(VOICES_DIR) if f.endswith(".pt")])
# ---------------------------------------------------------------------
# We'll map engine selection -> (build_model_func, generate_func)
# ---------------------------------------------------------------------
ENGINES = {
    "espeak": (build_model_espeak, generate_long_form_tts_espeak),
    "openphonemizer": (build_model_open, generate_long_form_tts_open),
}
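# Example lookup (illustrative): an unknown engine name raises KeyError,
# so callers may want to validate first, e.g.:
#   if engine not in ENGINES:
#       raise ValueError(f"unknown engine {engine!r}; choose from {list(ENGINES)}")
#   build_fn, gen_fn = ENGINES[engine]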
# ---------------------------------------------------------------------
# The main inference function (called by the FastAPI endpoint below;
# the Gradio UI at the bottom of this file is currently commented out)
# ---------------------------------------------------------------------
def tts_inference(text, engine, model_file, voice_file, speed=1.0):
    """
    text: input string
    engine: "espeak" or "openphonemizer"
    model_file: selected .pth from the models folder
    voice_file: selected .pt from the voices folder
    speed: speech speed
    """
    # 0) Ask LLAMA to answer the user query; the answer is what gets spoken
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": text + " Describe in one line only.",
                }  # ,
                # {
                #     "type": "image_url",
                #     "image_url": {
                #         "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                #     }
                # }
            ],
        }
    ]
    response_from_llama = client.chat.completions.create(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        messages=messages,
        max_tokens=500,
    )
    # 1) Map engine to the correct build_model + generate_long_form_tts
    build_fn, gen_fn = ENGINES[engine]
    # 2) Prepare paths
    model_path = os.path.join(MODELS_DIR, model_file)
    voice_path = os.path.join(VOICES_DIR, voice_file)
    # 3) Decide device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # 4) Load model and put every submodule in eval mode
    model = build_fn(model_path, device=device)
    for k, subm in model.items():
        if hasattr(subm, "eval"):
            subm.eval()
    # 5) Load voicepack
    voicepack = torch.load(voice_path, map_location=device)
    if hasattr(voicepack, "eval"):
        voicepack.eval()
    # 6) Generate TTS from the LLAMA answer
    audio, phonemes = gen_fn(
        model,
        response_from_llama.choices[0].message.content,
        voicepack,
        speed=speed,
    )
    sr = 24000  # Kokoro voicepacks produce 24 kHz audio
    return (sr, audio)  # (sample_rate, np_array)
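# ---------------------------------------------------------------------
# Example: calling tts_inference directly and saving the result to disk.
# Illustrative sketch only -- assumes the `soundfile` package is
# installed and that the default model/voice files below exist.
# ---------------------------------------------------------------------
# import soundfile as sf
# sr, audio = tts_inference(
#     text="What is AI?",
#     engine="openphonemizer",
#     model_file="kokoro-v0_19.pth",
#     voice_file="af_bella.pt",
# )
# sf.write("output.wav", audio, sr)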
# ---------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------
app = FastAPI()

class TTSRequest(BaseModel):
    text: str
    engine: str
    model_file: str
    voice_file: str
    speed: float = 1.0
@app.post("/tts")
def generate_tts(request: TTSRequest):
    try:
        sr, audio = tts_inference(
            text=request.text,
            engine=request.engine,
            model_file=request.model_file,
            voice_file=request.voice_file,
            speed=request.speed,
        )
        return JSONResponse(content={
            "sample_rate": sr,
            "audio_tensor": audio.tolist(),
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
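# ---------------------------------------------------------------------
# Example client: POST to the /tts endpoint and rebuild the waveform.
# Illustrative sketch only -- assumes the server is running locally on
# port 8000 and that `requests` and `numpy` are installed.
# ---------------------------------------------------------------------
# import requests
# import numpy as np
# resp = requests.post(
#     "http://localhost:8000/tts",
#     json={
#         "text": "What is AI?",
#         "engine": "openphonemizer",
#         "model_file": "kokoro-v0_19.pth",
#         "voice_file": "af_bella.pt",
#         "speed": 1.0,
#     },
# )
# payload = resp.json()
# audio = np.array(payload["audio_tensor"], dtype=np.float32)
# sr = payload["sample_rate"]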
###############################
# # ---------------------------------------------------------------------
# # Build Gradio App
# # ---------------------------------------------------------------------
# def create_gradio_app():
#     model_list = get_models()
#     voice_list = get_voices()
#     css = """
#     h4 {
#         text-align: center;
#         display: block;
#     }
#     h2 {
#         text-align: center;
#         display: block;
#     }
#     """
#     with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
#         gr.Markdown("## LLAMA TTS DEMO - API - GRADIO VISUAL")
#         # Row 1: Text input
#         text_input = gr.Textbox(
#             label="Enter your question",
#             value="What is AI?",
#             lines=2,
#         )
#         # Row 2: Engine selection
#         # engine_dropdown = gr.Dropdown(
#         #     choices=["espeak", "openphonemizer"],
#         #     value="openphonemizer",
#         #     label="Phonemizer",
#         # )
#         # Row 3: Model dropdown
#         # model_dropdown = gr.Dropdown(
#         #     choices=model_list,
#         #     value=model_list[0] if model_list else None,
#         #     label="Model (.pth)",
#         # )
#         # Row 4: Voice dropdown
#         # voice_dropdown = gr.Dropdown(
#         #     choices=voice_list,
#         #     value=voice_list[0] if voice_list else None,
#         #     label="Voice (.pt)",
#         # )
#         # Row 5: Speed slider
#         speed_slider = gr.Slider(
#             minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speech Speed"
#         )
#         # Generate button + audio output
#         generate_btn = gr.Button("Generate")
#         tts_output = gr.Audio(label="TTS Output")
#         # Connect the button to our inference function
#         generate_btn.click(
#             fn=tts_inference,
#             inputs=[
#                 text_input,
#                 gr.State("openphonemizer"),   # engine_dropdown
#                 gr.State("kokoro-v0_19.pth"), # model_dropdown
#                 gr.State("af_bella.pt"),      # voice_dropdown
#                 speed_slider,
#             ],
#             outputs=tts_output,
#         )
#         gr.Markdown("#### LLAMA - TTS")
#     return demo

# # ---------------------------------------------------------------------
# # Main
# # ---------------------------------------------------------------------
# if __name__ == "__main__":
#     app = create_gradio_app()
#     app.launch()