# Last commit: 9f4735a (unverified) by preslaff: "Fixed row_count to (100, dynamic)"
import os
import json
import requests
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import *
# Hub dataset that stores one ranked CSV per RL environment.
DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
DATASET_REPO_URL = f"https://huggingface.co/datasets/{DATASET_REPO_ID}"
# Token read from the environment (None when unset); authenticates the Hub client.
HF_TOKEN = os.getenv("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
block = gr.Blocks()
# Define RL environments.
# "rl_env" is both the Hub model filter and the stem of the leaderboard CSV;
# "rl_env_beautiful" is the tab label shown in the UI. "video_link" and
# "global" are not read by the code visible here; presumably kept for other
# consumers (e.g. `utils`), so verify before removing them.
rl_envs = [
    {"rl_env_beautiful": "LunarLander-v2 🚀", "rl_env": "LunarLander-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "CartPole-v1", "rl_env": "CartPole-v1", "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery ❄️", "rl_env": "FrozenLake-v1-4x4-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery ❄️", "rl_env": "FrozenLake-v1-8x8-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4 ❄️", "rl_env": "FrozenLake-v1-4x4", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8 ❄️", "rl_env": "FrozenLake-v1-8x8", "video_link": "", "global": None},
    {"rl_env_beautiful": "Taxi-v3 🚖", "rl_env": "Taxi-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v0 🏎️", "rl_env": "CarRacing-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v2 🏎️", "rl_env": "CarRacing-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "MountainCar-v0 ⛰️", "rl_env": "MountainCar-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "SpaceInvadersNoFrameskip-v4 👾", "rl_env": "SpaceInvadersNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "PongNoFrameskip-v4 🎾", "rl_env": "PongNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BreakoutNoFrameskip-v4 🧱", "rl_env": "BreakoutNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "QbertNoFrameskip-v4 🐦", "rl_env": "QbertNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BipedalWalker-v3", "rl_env": "BipedalWalker-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Walker2DBulletEnv-v0", "rl_env": "Walker2DBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "AntBulletEnv-v0", "rl_env": "AntBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "HalfCheetahBulletEnv-v0", "rl_env": "HalfCheetahBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v2", "rl_env": "PandaReachDense-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v3", "rl_env": "PandaReachDense-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Pixelcopter-PLE-v0", "rl_env": "Pixelcopter-PLE-v0", "video_link": "", "global": None},
]
# -------------------- Utility Functions --------------------
def restart():
    """Ask the Hub API to restart the leaderboard Space."""
    space_id = "huggingface-projects/Deep-Reinforcement-Learning-Leaderboard"
    print("RESTARTING SPACE...")
    api.restart_space(repo_id=space_id)
def download_leaderboard_dataset():
    """Download a snapshot of the leaderboard dataset repo and return its local path.

    Called once when the app starts and again by every scheduled update.

    NOTE(review): without `local_dir=`, snapshot_download returns a directory
    inside the huggingface_hub cache, and the update job writes CSVs into it.
    The hub docs advise against modifying cached files, so confirm this is intended.
    """
    print("Downloading leaderboard dataset...")
    local_path = snapshot_download(repo_type="dataset", repo_id=DATASET_REPO_ID)
    return local_path
def get_metadata(model_id):
    """Load the YAML metadata block of a Hub model's README.md.

    Args:
        model_id: Hub repo id of the model, e.g. "user/model".

    Returns:
        Whatever `metadata_load` returns for the downloaded card, or None when
        the download fails with an HTTP error (e.g. the repo has no README.md).
    """
    try:
        card_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        meta = metadata_load(card_path)
    except requests.exceptions.HTTPError:
        meta = None  # e.g. 404: README.md not found
    return meta
def parse_metrics_accuracy(meta):
    """Return the first metric value recorded in a model card's ``model-index``.

    Reads ``meta["model-index"][0]["results"][0]["metrics"][0]["value"]``.

    Args:
        meta: Parsed model-card metadata (a dict), or None.

    Returns:
        The raw metric value (typically a "mean +/- std" string), or None when
        ``meta`` is empty/None or its model-index is missing or malformed.
    """
    if not meta or "model-index" not in meta:
        return None
    try:
        return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
    except (KeyError, IndexError, TypeError):
        # Malformed card (empty lists, missing keys, null entries). The caller
        # runs under thread_map, which re-raises worker exceptions, so one bad
        # card would otherwise abort the update for every model.
        return None
def parse_rewards(accuracy):
    """Split a "mean +/- std" metric value into ``(mean_reward, std_reward)``.

    Args:
        accuracy: The raw metric value, e.g. ``"200.5 +/- 10.2"``, a bare
            number, or None.

    Returns:
        ``(mean, std)``. A value without "+/-" gets ``std = 0``. Missing
        (None), empty or unparsable values yield the sentinel ``(-1000, 0)``,
        so the model's score ``mean - std`` is -1000.
    """
    default_reward = -1000
    # Keep the std sentinel at 0. A negative sentinel would cancel the mean
    # sentinel in ``mean - std`` (-1000 - (-1000) == 0), ranking models without
    # metrics above every legitimately negative score (e.g. MountainCar).
    default_std = 0
    if accuracy is None:
        return default_reward, default_std
    parsed = str(accuracy).split("+/-")
    mean_text = parsed[0].strip()
    try:
        mean_reward = float(mean_text) if mean_text else default_reward
        std_reward = float(parsed[1].strip()) if len(parsed) > 1 else 0
    except ValueError:
        # Non-numeric metric text: treat it like a missing value instead of
        # aborting the whole leaderboard update.
        return default_reward, default_std
    return mean_reward, std_reward
def get_model_ids(rl_env):
    """Return the ids of the Hub models that ``api.list_models(filter=rl_env)`` yields."""
    model_infos = api.list_models(filter=rl_env)
    return [info.modelId for info in model_infos]
def update_leaderboard_dataset_parallel(rl_env, path):
    """Rebuild and save the ranked leaderboard CSV for one RL environment.

    Fetches model-card metadata concurrently for every model returned by
    ``get_model_ids(rl_env)``. Each model is scored as ``mean_reward - std_reward``;
    the rows are then ranked and written to ``<path>/<rl_env>.csv``.

    Args:
        rl_env: Environment id, used both as the Hub filter and the CSV file stem.
        path: Directory the CSV is written into.

    Returns:
        The ranked DataFrame that was written (possibly with no rows).
    """
    columns = ["User", "Model", "Results", "Mean Reward", "Std Reward"]

    def process_model(model_id):
        # One leaderboard row for model_id, or None if its card has no metadata.
        meta = get_metadata(model_id)
        if not meta:
            return None
        mean_reward, std_reward = parse_rewards(parse_metrics_accuracy(meta))
        return {
            "User": model_id.split("/")[0],
            "Model": model_id,
            "Results": mean_reward - std_reward,
            "Mean Reward": mean_reward,
            "Std Reward": std_reward,
        }

    model_ids = get_model_ids(rl_env)
    rows = [
        row
        for row in thread_map(process_model, model_ids, desc="Processing models")
        if row is not None
    ]
    # Explicit columns keep the frame well-formed when no model qualifies.
    # Without them an empty frame has no "Results" column and ranking raises KeyError.
    ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(rows, columns=columns))
    ranked_dataframe.to_csv(os.path.join(path, f"{rl_env}.csv"), index=False)
    return ranked_dataframe
def rank_dataframe(dataframe):
    """Order leaderboard rows best-first and prepend a 1-based "Ranking" column.

    Rows are sorted descending by "Results"; ties are broken by "User" and
    then "Model", both also descending. Returns a new DataFrame.
    """
    sort_keys = ["Results", "User", "Model"]
    ranked = dataframe.sort_values(sort_keys, ascending=False)
    positions = range(1, ranked.shape[0] + 1)
    ranked.insert(loc=0, column="Ranking", value=positions)
    return ranked
def run_update_dataset():
    """Scheduled job: rebuild every environment's leaderboard CSV and upload them.

    Downloads a fresh snapshot of the dataset repo, regenerates one CSV per
    entry in ``rl_envs`` inside it, then uploads the whole folder back to the
    dataset repo in a single commit.
    """
    snapshot_dir = download_leaderboard_dataset()
    env_names = [entry["rl_env"] for entry in rl_envs]
    for env_name in env_names:
        update_leaderboard_dataset_parallel(env_name, snapshot_dir)
    print("Uploading updated dataset...")
    api.upload_folder(
        repo_type="dataset",
        repo_id=DATASET_REPO_ID,
        folder_path=snapshot_dir,
        commit_message="Update dataset",
    )
def filter_data(rl_env, path, user_id):
    """Return the rows of ``rl_env``'s leaderboard CSV that belong to ``user_id``.

    Args:
        rl_env: Environment id; selects ``<path>/<rl_env>.csv``.
        path: Directory holding the leaderboard CSVs.
        user_id: Hub user name typed into the search box. Surrounding
            whitespace is ignored; None is treated as an empty query.

    Returns:
        A DataFrame containing only the matching rows (possibly empty).
    """
    query = "" if user_id is None else str(user_id).strip()
    # Read "User" as text. Otherwise a column whose user names are all digits
    # is parsed as integers and never equals the string query.
    data_df = pd.read_csv(os.path.join(path, f"{rl_env}.csv"), dtype={"User": str})
    return data_df[data_df["User"] == query]
# -------------------- Gradio UI --------------------
print("Initializing dataset...")
path_ = download_leaderboard_dataset()


def load_leaderboard(rl_env, path):
    """Read the full ranked leaderboard CSV for ``rl_env`` from directory ``path``."""
    return pd.read_csv(os.path.join(path, f"{rl_env}.csv"))


with block:
    gr.Markdown("""
# 🏆 Deep Reinforcement Learning Course Leaderboard 🏆
This leaderboard displays trained agents from the [Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt).
**Models are ranked using `mean_reward - std_reward`.**
If you can't find your model, please wait for the next update (every 2 hours).
""")
    # Dataset directory exposed to event handlers as a State input.
    grpath = gr.State(path_)
    for env in rl_envs:
        with gr.TabItem(env["rl_env_beautiful"]):
            gr.Markdown(f"## {env['rl_env_beautiful']}")
            user_id = gr.Textbox(label="Your user ID")
            search_btn = gr.Button("Search 🔎")
            reset_btn = gr.Button("Clear Search")
            # Pass this tab's env id to handlers as a State input. A closure
            # over the loop variable `env` would be resolved at click time and
            # always see the last environment in rl_envs.
            env_state = gr.State(env["rl_env"])
            gr_dataframe = gr.Dataframe(
                value=load_leaderboard(env["rl_env"], path_),
                headers=["Ranking 🏆", "User 🤗", "Model 🤖", "Results", "Mean Reward", "Std Reward"],
                datatype=["number", "markdown", "markdown", "number", "number", "number"],
                # 100 initial rows; "dynamic" (vs "fixed") lets the row count change.
                row_count=(100, "dynamic"),
            )
            search_btn.click(fn=filter_data, inputs=[env_state, grpath, user_id], outputs=gr_dataframe)
            reset_btn.click(fn=load_leaderboard, inputs=[env_state, grpath], outputs=gr_dataframe)
# -------------------- Scheduler --------------------
scheduler = BackgroundScheduler()
# (job, interval in hours): refresh + upload the leaderboard CSVs, then restart the Space.
# NOTE(review): presumably each restart also restarts this scheduler, resetting
# the 2-hour timer, so updates may in practice run ~3 hours apart; confirm.
for job, every_hours in ((run_update_dataset, 2), (restart, 3)):
    scheduler.add_job(job, "interval", hours=every_hours)
scheduler.start()
block.launch()