# osanseviero's picture
# Fix token = (#2)
# 5b73951
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import os
import pickle
import pandas as pd
import gradio as gr
# Turn off SettingWithCopyWarning
pd.options.mode.chained_assignment = None

# Token used for the private lyrics dataset. Falls back to True when the
# TOKEN_FROM_SECRET environment variable is unset or empty.
# NOTE(review): True presumably tells huggingface_hub to use the locally
# stored login token -- confirm against the installed hub version.
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

# Pre-computed verse embeddings; row positions are used as indices into `verses`.
# SECURITY NOTE: pickle executes arbitrary code while loading. This is only
# acceptable as long as the dataset repository is fully trusted.
_embeddings_path = hf_hub_download(
    "NimaBoscarino/playlist-generator",
    repo_type="dataset",
    filename="verse-embeddings.pkl",
)
# Use a context manager so the file handle is closed once loading is done.
with open(_embeddings_path, "rb") as _embeddings_file:
    corpus_embeddings = pickle.load(_embeddings_file)

# Song metadata; the "song_id", "full_title" and "art" columns are read below.
songs = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator",
        repo_type="dataset",
        filename="songs_new.csv",
    )
)

# Verse table; the "song_id" column links each verse to a song.
verses = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator",
        repo_type="dataset",
        filename="verses.csv",
    )
)

# Lyrics live in a separate, private dataset and therefore need the token.
lyrics = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator-private",
        repo_type="dataset",
        filename="lyrics_new.csv",
        use_auth_token=auth_token,
    )
)

# Model used to embed the user's prompt at query time.
embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
def generate_playlist(prompt):
    """Build a playlist of up to nine songs whose verses best match ``prompt``.

    The prompt is embedded with the module-level ``embedder`` and searched
    against ``corpus_embeddings``. The 20 best verse hits are mapped to songs
    through ``verses`` and ``songs``, keeping one entry per song in the order
    the hits were returned.

    Args:
        prompt: Free-text prompt entered by the user.

    Returns:
        A tuple of ten Gradio updates: one ``gr.Radio`` update whose choices
        are the matched song titles, followed by exactly nine ``gr.Image``
        updates. When fewer than nine songs match, the remaining tiles are
        reset to the placeholder artwork so the number of returned values
        always equals the number of wired output components.
    """
    placeholder_art = "https://i.imgur.com/bgCDfT1.jpg"
    tile_count = 9  # Must match the number of image tiles wired as outputs.

    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
    hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

    # Several hits can belong to the same song; keep the first hit per song.
    verse_match = verses.iloc[hits['corpus_id']].drop_duplicates(subset=["song_id"])
    ranked_song_ids = verse_match["song_id"].values

    # Work on a copy so the column assignment below never touches `songs`.
    song_match = songs[songs["song_id"].isin(ranked_song_ids)].copy()

    # An ordered categorical makes sort_values follow the hit order rather
    # than the natural ordering of the ids.
    song_match["song_id"] = pd.Categorical(
        song_match["song_id"], categories=ranked_song_ids
    )
    song_match = song_match.sort_values("song_id").head(tile_count)

    song_names = list(song_match["full_title"])
    song_art = list(song_match["art"].fillna(placeholder_art))

    # Pad to a full grid: returning fewer updates than there are output
    # components makes the event handler fail.
    song_art.extend([placeholder_art] * (tile_count - len(song_art)))

    images = [gr.Image.update(value=art, visible=True) for art in song_art]

    return (
        gr.Radio.update(label="Songs", interactive=True, choices=song_names),
        *images
    )
def set_lyrics(full_title):
    """Return a textbox update showing the lyrics of the selected song.

    The title is resolved to song ids through ``songs`` and the first matching
    row of ``lyrics`` supplies the text.

    Args:
        full_title: Value of the song radio; may be ``None`` when nothing is
            selected.

    Returns:
        A ``gr.Textbox`` update carrying the lyrics text, or an empty string
        when no lyrics row matches (instead of raising ``IndexError``).
    """
    song_ids = songs.loc[songs["full_title"] == full_title, "song_id"]
    matching_text = lyrics.loc[lyrics["song_id"].isin(song_ids), "text"]

    # No selection, an unknown title, or a song without a lyrics entry.
    if matching_text.empty:
        return gr.Textbox.update(value="")

    return gr.Textbox.update(value=matching_text.iloc[0])
def set_example_prompt(example):
    """Return a text-area update filled with the first entry of ``example``.

    Args:
        example: Indexable sample; only ``example[0]`` is read.

    Returns:
        A ``gr.TextArea`` update whose value is ``example[0]``.
    """
    prompt_text = example[0]
    return gr.TextArea.update(value=prompt_text)
# NOTE(review): indentation was lost in the reviewed copy of this file; the
# nesting below is reconstructed from the statement order -- confirm that it
# matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("\n# Playlist Generator 📻 🎵\n")

    with gr.Row():
        # Prompt entry and example prompts.
        with gr.Column():
            gr.Markdown(
                "\n"
                "Enter a prompt and generate a playlist based on ✨semantic similarity✨\n"
                "This was built using Sentence Transformers and Gradio – [see the blog](https://huggingface.co/blog/your-first-ml-project)!\n"
            )

            song_prompt = gr.TextArea(
                value="Running wild and free",
                placeholder="Enter a song prompt, or choose an example",
            )
            example_prompts = gr.Dataset(
                components=[song_prompt],
                samples=[
                    ["I feel nostalgic for the past"],
                    ["Running wild and free"],
                    ["I'm deeply in love with someone I just met!"],
                    ["My friends mean the world to me"],
                    ["Sometimes I feel like no one understands"],
                ],
            )

        # Generate button, 3x3 grid of cover art, and the song selector.
        with gr.Column():
            fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽‍🎤").style(full_width=True)

            # Workaround because of the Gallery issues: nine individual image
            # tiles, created row by row and all starting on the placeholder.
            tiles = []
            for _row in range(3):
                with gr.Row():
                    for _col in range(3):
                        tiles.append(
                            gr.Image(
                                value="https://i.imgur.com/bgCDfT1.jpg",
                                show_label=False,
                                visible=True,
                            )
                        )

            song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")

        # Lyrics of the currently selected song.
        with gr.Column():
            verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")

    # Clicking the button refreshes the song choices and all nine tiles.
    fetch_songs.click(
        fn=generate_playlist,
        inputs=[song_prompt],
        outputs=[song_option, *tiles],
    )

    # Clicking an example copies it into the dataset's components.
    example_prompts.click(
        fn=set_example_prompt,
        inputs=example_prompts,
        outputs=example_prompts.components,
    )

    # Selecting a song shows its lyrics.
    song_option.change(
        fn=set_lyrics,
        inputs=[song_option],
        outputs=[verse],
    )

demo.launch()