from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTextToWaveform
import torch
from scipy.io.wavfile import write
import numpy as np
import uuid
import os

app = FastAPI()

# Load the MusicGen model and its text tokenizer once at startup
model_name = "facebook/musicgen-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTextToWaveform.from_pretrained(model_name)

# Run on GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Directory where generated WAV files are written
OUTPUT_DIR = "generated_audio"
os.makedirs(OUTPUT_DIR, exist_ok=True)


class MusicRequest(BaseModel):
    prompt: str


@app.post("/generate-music/")
def generate_music(request: MusicRequest):
    try:
        # Tokenize the text prompt and move the tensors to the model's device
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_values = model.generate(**inputs)

        # MusicGen exposes the output sampling rate on its audio encoder config
        sampling_rate = model.config.audio_encoder.sampling_rate
        audio_values = audio_values.cpu().numpy().squeeze()

        # Write the waveform to a uniquely named 16-bit PCM WAV file
        file_id = str(uuid.uuid4())[:8]
        output_wav_path = os.path.join(OUTPUT_DIR, f"music_{file_id}.wav")
        write(output_wav_path, sampling_rate, np.int16(audio_values * 32767))

        return {"message": "Music generated successfully", "file_path": output_wav_path}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
def root():
    return {"message": "Welcome to the Music Generation API"}
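
# --- Example usage (a sketch, not part of the API itself) ---
# Assuming this module is saved as `main.py`, the server can be started with:
#   uvicorn main:app --reload
# and exercised with a small client such as (requires the `requests` package):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/generate-music/",
#       json={"prompt": "a calm acoustic guitar melody"},
#   )
#   print(resp.json())
#   # -> {"message": "Music generated successfully",
#   #     "file_path": "generated_audio/music_xxxxxxxx.wav"}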