# Scraped file-view header: Botpy-808 / app.py
# Last commit ca8f9b4 ("Update app.py", verified) by Fred808; file size 1.59 kB.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTextToWaveform
import torch
from scipy.io.wavfile import write
import numpy as np
import uuid
import os
app = FastAPI()

# Checkpoint id handed to from_pretrained; the tokenizer and the model are
# both loaded from it once, when this module is imported.
model_name = "facebook/musicgen-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTextToWaveform.from_pretrained(model_name)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Destination folder for the generated WAV files; created up front
# (exist_ok=True makes this a no-op when it already exists).
OUTPUT_DIR = "generated_audio"
os.makedirs(OUTPUT_DIR, exist_ok=True)
class MusicRequest(BaseModel):
    """JSON request body accepted by ``POST /generate-music/``.

    Attributes:
        prompt: Free-text description of the music to generate; the endpoint
            passes it to the tokenizer as-is.
    """

    prompt: str
def _to_int16_pcm(audio: np.ndarray) -> np.ndarray:
    """Convert a float waveform to 16-bit PCM for ``scipy.io.wavfile.write``.

    Samples are clipped to [-1.0, 1.0] before scaling by 32767 so that values
    slightly outside that range saturate instead of wrapping around in int16
    (which would be heard as loud clicks). A 2-D array is treated as
    (channels, samples) and transposed to (samples, channels), the layout
    ``wavfile.write`` expects for multi-channel data; 1-D (mono) input is
    only clipped and scaled.
    """
    clipped = np.clip(audio, -1.0, 1.0)
    if clipped.ndim == 2:
        # assumes a squeezed (channels, samples) layout — confirm if a
        # multi-channel checkpoint is ever configured.
        clipped = clipped.T
    return (clipped * 32767).astype(np.int16)


@app.post("/generate-music/")
def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt and save it as a WAV file.

    The prompt is tokenized, passed to ``model.generate``, and the resulting
    waveform is written to ``OUTPUT_DIR`` under a random unique file name.

    Args:
        request: Request body carrying the text ``prompt``.

    Returns:
        A dict with a success ``message`` and the server-side ``file_path``
        of the written WAV file.

    Raises:
        HTTPException: status 500, with the error text as detail, if any
            step (tokenizing, generation, or writing the file) fails.
    """
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            audio_values = model.generate(**inputs)
        sampling_rate = model.config.sampling_rate
        # Drop singleton dimensions, e.g. a single mono clip becomes 1-D
        # (assumes generate() returns (batch, channels, samples) — confirm).
        audio = audio_values.cpu().numpy().squeeze()
        # Use the full 128-bit UUID: an 8-character prefix keeps only 32
        # random bits and can collide, silently overwriting an earlier file.
        file_id = uuid.uuid4().hex
        output_wav_path = os.path.join(OUTPUT_DIR, f"music_{file_id}.wav")
        write(output_wav_path, sampling_rate, _to_int16_pcm(audio))
        return {"message": "Music generated successfully", "file_path": output_wav_path}
    except Exception as e:
        # Top-level API boundary: report any failure as a 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def root():
    """Handle ``GET /`` by returning a static welcome payload."""
    payload = {"message": "Welcome to the Music Generation API"}
    return payload