Fred808 commited on
Commit
ca8f9b4
·
verified ·
1 Parent(s): 1f555d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -1
app.py CHANGED
@@ -1 +1,50 @@
1
- pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForTextToWaveform
4
+ import torch
5
+ from scipy.io.wavfile import write
6
+ import numpy as np
7
+ import uuid
8
+ import os
9
+
10
+ app = FastAPI()
11
+
12
+ # Load model and tokenizer
13
+ model_name = "facebook/musicgen-small"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForTextToWaveform.from_pretrained(model_name)
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model.to(device)
19
+
20
+ # Directory to save generated audio files
21
+ OUTPUT_DIR = "generated_audio"
22
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
23
+
24
+ class MusicRequest(BaseModel):
25
+ prompt: str
26
+
27
+ @app.post("/generate-music/")
28
+ def generate_music(request: MusicRequest):
29
+ try:
30
+ inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
31
+ with torch.no_grad():
32
+ audio_values = model.generate(**inputs)
33
+
34
+ sampling_rate = model.config.sampling_rate
35
+ audio_values = audio_values.cpu().numpy().squeeze()
36
+
37
+ # Generate a unique filename
38
+ file_id = str(uuid.uuid4())[:8]
39
+ output_wav_path = os.path.join(OUTPUT_DIR, f"music_{file_id}.wav")
40
+
41
+ # Save audio file
42
+ write(output_wav_path, sampling_rate, np.int16(audio_values * 32767))
43
+
44
+ return {"message": "Music generated successfully", "file_path": output_wav_path}
45
+ except Exception as e:
46
+ raise HTTPException(status_code=500, detail=str(e))
47
+
48
+ @app.get("/")
49
+ def root():
50
+ return {"message": "Welcome to the Music Generation API"}