|
import os |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import gradio as gr |
|
|
|
|
|
# Authenticate with the Hugging Face Hub. The token must come from the
# HF_TOKEN environment variable — never hard-code credentials in source.
hf_token = os.getenv("HF_TOKEN")

if not hf_token:

    raise ValueError("Hugging Face token not found in environment variables.")


# Load the tokenizer and causal-LM weights for the music model.
# NOTE: `use_auth_token` is deprecated in recent transformers releases;
# `token=` is the supported keyword for passing an access token.
model_name = "sander-wood/music-transformer"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)


# Run inference on GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
|
|
|
|
|
def generate_music(input_text, max_length=512, temperature=0.9, top_p=0.95):
    """Generate music (ABC notation) from a text prompt.

    Args:
        input_text: Text description used as the generation prompt.
        max_length: Maximum total token length of the generated sequence.
            Cast to ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature (> 0); lower is more deterministic.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The decoded generated sequence as a string, special tokens removed.
    """
    # Tokenize and move tensors to the same device as the model.
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            # Unpack the full encoding so attention_mask is forwarded along
            # with input_ids; passing only input_ids drops the mask and can
            # degrade generation when padding is present.
            **inputs,
            max_length=int(max_length),  # slider values may be floats
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # Decode the single returned sequence back to text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def gradio_interface(input_text, max_length, temperature, top_p):
    """Gradio-facing wrapper: forward the UI widget values to generate_music
    and return the generated ABC-notation string unchanged."""
    return generate_music(input_text, max_length, temperature, top_p)
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
# Widget order must match the positional parameters of gradio_interface.
inputs = [
    gr.Textbox(label="Input Text", placeholder="Enter a music description..."),
    # step=1 keeps the slider value an integer, as expected for a token count.
    gr.Slider(minimum=64, maximum=1024, value=512, step=1, label="Max Length"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
]

outputs = gr.Textbox(label="Generated Music (ABC Notation)")

app = gr.Interface(
    fn=gradio_interface,
    inputs=inputs,
    outputs=outputs,
    title="Music Transformer",
    description="Generate music in ABC notation using the sander-wood/music-transformer model.",
)

# Bind to all interfaces so the app is reachable from outside a container;
# 7860 is Gradio's conventional port.
app.launch(server_name="0.0.0.0", server_port=7860)