Fred808 commited on
Commit
22f7e0f
·
verified ·
1 Parent(s): 36a52c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -39
app.py CHANGED
@@ -1,51 +1,79 @@
1
- from fastapi import FastAPI, HTTPException, Response
 
 
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModel
4
- import torch
5
- import numpy as np
6
- import io
7
- from scipy.io.wavfile import write
8
- from PIL import Image
9
- import riffusion
10
 
11
  app = FastAPI()
12
 
13
- # Load Riffusion model
14
- model_name = "riffusion/riffusion-model-v1"
15
- model = AutoModel.from_pretrained(model_name)
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
 
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model.to(device)
 
 
 
 
20
 
21
- class MusicRequest(BaseModel):
22
- prompt: str
 
 
 
 
 
 
 
 
23
 
24
- @app.post("/generate-music/")
25
- def generate_music(request: MusicRequest):
 
 
 
 
 
 
 
 
 
26
  try:
27
- inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
28
- with torch.no_grad():
29
- spectrogram = model.generate(**inputs)
30
-
31
- # Convert spectrogram to an image (since Riffusion outputs spectrograms)
32
- spectrogram_image = Image.fromarray((spectrogram.cpu().numpy().squeeze() * 255).astype(np.uint8))
33
-
34
- # Convert spectrogram to audio
35
- audio_values, sampling_rate = riffusion.audio_processing.spectrogram_to_audio(spectrogram_image)
36
-
37
- # Normalize and convert to int16
38
- audio_values = np.clip(audio_values * 32767, -32768, 32767).astype(np.int16)
39
-
40
- # Convert to WAV format
41
- audio_bytes = io.BytesIO()
42
- write(audio_bytes, sampling_rate, audio_values)
43
- audio_bytes.seek(0)
44
-
45
- return Response(content=audio_bytes.read(), media_type="audio/wav", headers={"Content-Disposition": "attachment; filename=generated_music.wav"})
46
  except Exception as e:
47
  raise HTTPException(status_code=500, detail=str(e))
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @app.get("/")
50
- def root():
51
- return {"message": "Welcome to the Riffusion Music Generation API"}
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ import requests
3
+ import base64
4
  from pydantic import BaseModel
5
+ from typing import Optional
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
+ # NVIDIA API endpoint and API key
10
+ NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
11
+ API_KEY = "your_nvidia_api_key_here" # Replace with your actual API key
 
12
 
13
+ # Request model for text-based input
14
+ class TextRequest(BaseModel):
15
+ message: str
16
+ max_tokens: Optional[int] = 512
17
+ temperature: Optional[float] = 1.0
18
+ top_p: Optional[float] = 1.0
19
 
20
+ # Function to call the NVIDIA API
21
+ def call_nvidia_api(payload: dict):
22
+ headers = {
23
+ "Authorization": f"Bearer {API_KEY}",
24
+ "Accept": "application/json",
25
+ }
26
+ response = requests.post(NVIDIA_API_URL, headers=headers, json=payload)
27
+ if response.status_code != 200:
28
+ raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")
29
+ return response.json()
30
 
31
+ # Endpoint for text-based input
32
+ @app.post("/chat/text")
33
+ async def chat_with_text(request: TextRequest):
34
+ payload = {
35
+ "model": "meta/llama-3.2-90b-vision-instruct",
36
+ "messages": [{"role": "user", "content": request.message}],
37
+ "max_tokens": request.max_tokens,
38
+ "temperature": request.temperature,
39
+ "top_p": request.top_p,
40
+ "stream": False,
41
+ }
42
  try:
43
+ response = call_nvidia_api(payload)
44
+ return {"response": response["choices"][0]["message"]["content"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  except Exception as e:
46
  raise HTTPException(status_code=500, detail=str(e))
47
 
48
+ # Endpoint for image-based input
49
+ @app.post("/chat/image")
50
+ async def chat_with_image(file: UploadFile = File(...)):
51
+ # Read and encode the image file to base64
52
+ image_data = await file.read()
53
+ base64_image = base64.b64encode(image_data).decode("utf-8")
54
+
55
+ # Prepare the payload for the NVIDIA API
56
+ payload = {
57
+ "model": "meta/llama-3.2-90b-vision-instruct",
58
+ "messages": [
59
+ {
60
+ "role": "user",
61
+ "content": f'What is in this image? <img src="data:image/png;base64,{base64_image}" />',
62
+ }
63
+ ],
64
+ "max_tokens": 512,
65
+ "temperature": 1.0,
66
+ "top_p": 1.0,
67
+ "stream": False,
68
+ }
69
+
70
+ try:
71
+ response = call_nvidia_api(payload)
72
+ return {"response": response["choices"][0]["message"]["content"]}
73
+ except Exception as e:
74
+ raise HTTPException(status_code=500, detail=str(e))
75
+
76
+ # Root endpoint
77
  @app.get("/")
78
+ async def root():
79
+ return {"message": "Welcome to the NVIDIA API FastAPI wrapper!"}