Botpy-808 / app.py
Fred808's picture
Update app.py
b4e27c2 verified
raw
history blame
1.8 kB
from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import io
from scipy.io.wavfile import write
from PIL import Image
import riffusion
# Application instance that the route decorators in this module register on.
app = FastAPI()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load Riffusion model
# NOTE(review): confirm this checkpoint can actually be loaded through
# AutoModel / AutoTokenizer; it may be published in a different layout
# (e.g. as a diffusers pipeline) — verify against the model repository.
model_name = "riffusion/riffusion-model-v1"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Move the weights to the selected device once, at import time.
model.to(device)
# Request body accepted by the /generate-music/ endpoint.
class MusicRequest(BaseModel):
    # Text that is tokenized and passed to the model.
    prompt: str
@app.post("/generate-music/")
def generate_music(request: MusicRequest):
    """Generate audio for a text prompt and return it as a WAV download.

    The prompt is tokenized and passed to ``model.generate``; the result is
    rendered as an 8-bit image, converted to audio samples, scaled to 16-bit
    PCM and encoded as a WAV file in memory.

    Args:
        request: Request body carrying the text ``prompt``.

    Returns:
        A ``Response`` with media type ``audio/wav`` whose body is the WAV
        file, sent as an attachment named ``generated_music.wav``.

    Raises:
        HTTPException: Status 500 with the underlying error message as
            ``detail`` when any step of the pipeline fails.
    """
    import logging

    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # NOTE(review): assumes the loaded model exposes ``generate`` and
            # that its output is a spectrogram tensor — confirm.
            spectrogram = model.generate(**inputs)

        # Convert spectrogram to an image (since Riffusion outputs spectrograms)
        # NOTE(review): presumably values are in [0, 1] and squeeze() leaves a
        # 2-D array — verify against the model output.
        pixels = spectrogram.cpu().numpy().squeeze() * 255
        # Clamp before the uint8 cast: casting out-of-range floats to an
        # unsigned byte wraps around (or is undefined) instead of saturating.
        # This mirrors the clipping applied to the int16 audio below.
        spectrogram_image = Image.fromarray(
            np.clip(pixels, 0, 255).astype(np.uint8)
        )

        # Convert spectrogram to audio
        # NOTE(review): relies on ``riffusion.audio_processing`` being
        # reachable after a bare ``import riffusion`` — confirm.
        audio_values, sampling_rate = (
            riffusion.audio_processing.spectrogram_to_audio(spectrogram_image)
        )

        # Normalize and convert to int16
        audio_values = np.clip(
            audio_values * 32767, -32768, 32767
        ).astype(np.int16)

        # Convert to WAV format
        audio_bytes = io.BytesIO()
        write(audio_bytes, sampling_rate, audio_values)

        # getvalue() returns the whole buffer regardless of stream position.
        return Response(
            content=audio_bytes.getvalue(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=generated_music.wav"},
        )
    except Exception as e:
        # Keep the traceback in the server log; the client only gets str(e).
        logging.getLogger(__name__).exception("Music generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def root():
    # Landing endpoint: returns a fixed welcome payload.
    welcome = {"message": "Welcome to the Riffusion Music Generation API"}
    return welcome