Fred808 commited on
Commit
b4e27c2
·
verified ·
1 Parent(s): a3a5240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -1,17 +1,19 @@
1
  from fastapi import FastAPI, HTTPException, Response
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForTextToWaveform
4
  import torch
5
- from scipy.io.wavfile import write
6
  import numpy as np
7
  import io
 
 
 
8
 
9
  app = FastAPI()
10
 
11
- # Load model and tokenizer
12
- model_name = "facebook/musicgen-medium"
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = AutoModelForTextToWaveform.from_pretrained(model_name, attn_implementation="eager")
15
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
@@ -24,15 +26,18 @@ def generate_music(request: MusicRequest):
24
  try:
25
  inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
26
  with torch.no_grad():
27
- audio_values = model.generate(**inputs)
 
 
 
28
 
29
- sampling_rate = model.config.sampling_rate
30
- audio_values = audio_values.cpu().numpy().squeeze()
31
 
32
- # Normalize audio values to fit int16 range
33
  audio_values = np.clip(audio_values * 32767, -32768, 32767).astype(np.int16)
34
 
35
- # Convert audio to bytes
36
  audio_bytes = io.BytesIO()
37
  write(audio_bytes, sampling_rate, audio_values)
38
  audio_bytes.seek(0)
@@ -43,4 +48,4 @@ def generate_music(request: MusicRequest):
43
 
44
  @app.get("/")
45
  def root():
46
- return {"message": "Welcome to the Music Generation API"}
 
1
  from fastapi import FastAPI, HTTPException, Response
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModel
4
  import torch
 
5
  import numpy as np
6
  import io
7
+ from scipy.io.wavfile import write
8
+ from PIL import Image
9
+ import riffusion
10
 
11
  app = FastAPI()
12
 
13
+ # Load Riffusion model
14
+ model_name = "riffusion/riffusion-model-v1"
15
+ model = AutoModel.from_pretrained(model_name)
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
 
26
  try:
27
  inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
28
  with torch.no_grad():
29
+ spectrogram = model.generate(**inputs)
30
+
31
+ # Convert spectrogram to an image (since Riffusion outputs spectrograms)
32
+ spectrogram_image = Image.fromarray((spectrogram.cpu().numpy().squeeze() * 255).astype(np.uint8))
33
 
34
+ # Convert spectrogram to audio
35
+ audio_values, sampling_rate = riffusion.audio_processing.spectrogram_to_audio(spectrogram_image)
36
 
37
+ # Normalize and convert to int16
38
  audio_values = np.clip(audio_values * 32767, -32768, 32767).astype(np.int16)
39
 
40
+ # Convert to WAV format
41
  audio_bytes = io.BytesIO()
42
  write(audio_bytes, sampling_rate, audio_values)
43
  audio_bytes.seek(0)
 
48
 
49
  @app.get("/")
50
  def root():
51
+ return {"message": "Welcome to the Riffusion Music Generation API"}