|
from fastapi import FastAPI, HTTPException, Response |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import numpy as np |
|
import io |
|
from scipy.io.wavfile import write |
|
from PIL import Image |
|
import riffusion |
|
|
|
app = FastAPI()


# Hugging Face model id for Riffusion (Stable-Diffusion-based spectrogram
# generator).  NOTE(review): AutoModel/AutoTokenizer are intended for
# standard transformer checkpoints -- confirm this diffusion model id
# actually loads through the transformers Auto* classes.
model_name = "riffusion/riffusion-model-v1"

# Loaded once at import time so every request reuses the same weights.
model = AutoModel.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)


# Run inference on GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
|
|
|
class MusicRequest(BaseModel):
    """Request body for /generate-music/.

    Attributes carried by the JSON payload:
        prompt: free-text description of the music to generate.
    """

    prompt: str
|
|
|
@app.post("/generate-music/")
def generate_music(request: MusicRequest):
    """Generate a WAV audio clip from a text prompt.

    Tokenizes the prompt, runs the Riffusion model to produce a
    spectrogram, renders the spectrogram as a grayscale image, converts
    it to audio, and returns the result as a downloadable 16-bit PCM
    WAV file.

    Args:
        request: validated body containing the text ``prompt``.

    Returns:
        fastapi.Response with ``audio/wav`` content and an attachment
        Content-Disposition header.

    Raises:
        HTTPException: 500 carrying the underlying error message if any
            step of the pipeline fails.
    """
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # NOTE(review): assumes model.generate yields a spectrogram
            # tensor with values roughly in [0, 1] -- confirm for this
            # checkpoint.
            spectrogram = model.generate(**inputs)

        # Clip BEFORE the uint8 cast: values outside [0, 1] would
        # otherwise wrap modulo 256 and corrupt the spectrogram image
        # (the int16 audio conversion below already clips for the same
        # reason).
        pixels = np.clip(spectrogram.cpu().numpy().squeeze() * 255, 0, 255)
        spectrogram_image = Image.fromarray(pixels.astype(np.uint8))

        # NOTE(review): relies on riffusion exposing
        # audio_processing.spectrogram_to_audio -- verify against the
        # installed riffusion version.
        audio_values, sampling_rate = riffusion.audio_processing.spectrogram_to_audio(spectrogram_image)

        # Scale float audio to the 16-bit PCM range, clipping to avoid
        # integer wraparound on out-of-range samples.
        audio_values = np.clip(audio_values * 32767, -32768, 32767).astype(np.int16)

        audio_bytes = io.BytesIO()
        write(audio_bytes, sampling_rate, audio_values)
        audio_bytes.seek(0)

        return Response(
            content=audio_bytes.read(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=generated_music.wav"},
        )
    except Exception as e:
        # Boundary handler: surface any pipeline failure as a 500 while
        # preserving the original traceback via exception chaining.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.get("/")
def root():
    """Landing endpoint: return a static welcome payload."""
    welcome = {"message": "Welcome to the Riffusion Music Generation API"}
    return welcome
|
|