"""Gradio app that samples text from a Hugging Face causal LM and shows it as music (ABC notation)."""

import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Read the Hugging Face access token from the environment; the app refuses to
# start without it.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("Hugging Face token not found in environment variables.")

# Load the tokenizer and model.
# `token=` is the current name of the deprecated `use_auth_token=` argument.
model_name = "sander-wood/music-transformer"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)

# Move the model to the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def generate_music(input_text, max_length=512, temperature=0.9, top_p=0.95):
    """Sample one continuation of ``input_text`` from the loaded model.

    Args:
        input_text: Prompt text passed to the tokenizer.
        max_length: Maximum total length, in tokens, of the generated sequence
            *including* the prompt tokens (``model.generate`` ``max_length``
            semantics). Coerced to ``int``.
        temperature: Sampling temperature; must be > 0 since sampling is on.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The decoded sequence with special tokens removed.

    Raises:
        ValueError: If ``input_text`` tokenizes to zero tokens.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    # A zero-length prompt cannot be fed to generate(); fail with a clear
    # message instead of an obscure indexing error deep inside the model.
    if input_ids.shape[-1] == 0:
        raise ValueError(
            "The input text produced no tokens; please enter a non-empty prompt."
        )

    # Many causal-LM tokenizers define no pad token; fall back to EOS so that
    # generate() does not have to guess (and warn) on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            # Pass the mask explicitly instead of letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_length=int(max_length),  # UI sliders may deliver floats
            num_return_sequences=1,
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # NOTE: for a decoder-only (causal) model, outputs[0] begins with the
    # prompt tokens, so the returned text includes the prompt itself.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def gradio_interface(input_text, max_length, temperature, top_p):
    """Gradio callback: forward the UI values to ``generate_music``.

    A ``ValueError`` raised during generation is re-raised as ``gr.Error`` so
    its message is displayed in the UI instead of a generic failure.
    """
    try:
        return generate_music(input_text, max_length, temperature, top_p)
    except ValueError as err:
        raise gr.Error(str(err)) from err


# Gradio inputs and outputs.
inputs = [
    gr.Textbox(label="Input Text", placeholder="Enter a music description..."),
    # step=1 keeps the token budget integral.
    gr.Slider(minimum=64, maximum=1024, value=512, step=1, label="Max Length"),
    # Minimum stays above 0: sampling (do_sample=True) needs temperature > 0.
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
]
outputs = gr.Textbox(label="Generated Music (ABC Notation)")

# Create the Gradio app.
app = gr.Interface(
    fn=gradio_interface,
    inputs=inputs,
    outputs=outputs,
    title="Music Transformer",
    description="Generate music in ABC notation using the sander-wood/music-transformer model.",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860; launching only when run as a
    # script keeps `import` of this module free of a blocking server start.
    app.launch(server_name="0.0.0.0", server_port=7860)