"""FastAPI service that turns a text prompt into audio via the Riffusion model."""

import io
import logging

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Response
from PIL import Image
from pydantic import BaseModel
from scipy.io.wavfile import write
from transformers import AutoModel, AutoTokenizer

import riffusion

logger = logging.getLogger(__name__)

app = FastAPI()

# Load the Riffusion checkpoint once at startup so every request reuses the weights.
# NOTE(review): Riffusion is a Stable Diffusion fine-tune; loading it via
# AutoModel/AutoTokenizer and calling ``model.generate`` below may not match the
# published pipeline (normally driven through ``diffusers``) — confirm against
# the riffusion project documentation before relying on this path.
model_name = "riffusion/riffusion-model-v1"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class MusicRequest(BaseModel):
    # Free-text description of the music to generate.
    prompt: str


@app.post("/generate-music/")
def generate_music(request: MusicRequest):
    """Generate audio for ``request.prompt`` and return it as a WAV download.

    Returns:
        A ``Response`` with media type ``audio/wav`` and a Content-Disposition
        header naming the file ``generated_music.wav``.

    Raises:
        HTTPException: 500 with the underlying error message if any stage of
            generation fails.
    """
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            spectrogram = model.generate(**inputs)

        # Riffusion outputs spectrograms, so render the tensor as an 8-bit
        # grayscale image. NOTE(review): assumes tensor values lie in [0, 1];
        # out-of-range values would wrap on the uint8 cast — confirm.
        spectrogram_image = Image.fromarray(
            (spectrogram.cpu().numpy().squeeze() * 255).astype(np.uint8)
        )

        # Reconstruct a waveform from the spectrogram image.
        # NOTE(review): verify this helper exists in the installed riffusion
        # package version — the API surface has changed between releases.
        audio_values, sampling_rate = riffusion.audio_processing.spectrogram_to_audio(
            spectrogram_image
        )

        # Scale float samples into the int16 range expected by PCM WAV,
        # clipping rather than overflowing on out-of-range values.
        audio_values = np.clip(audio_values * 32767, -32768, 32767).astype(np.int16)

        # Serialize as an in-memory WAV file.
        audio_bytes = io.BytesIO()
        write(audio_bytes, sampling_rate, audio_values)
        audio_bytes.seek(0)

        return Response(
            content=audio_bytes.read(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=generated_music.wav"},
        )
    except HTTPException:
        # Let deliberate HTTP errors propagate instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        # Log the full traceback server-side; chain the cause for debuggability.
        logger.exception("Music generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
def root():
    """Health/landing endpoint."""
    return {"message": "Welcome to the Riffusion Music Generation API"}