Add app.py
- app.py +222 -0
- lerobot_datasets.csv +0 -0
app.py ADDED
@@ -0,0 +1,222 @@
import ast
import json

import gradio as gr
import pandas as pd
import spacy
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata


def analyze_dataset_metadata(repo_id: str):
    # Hub datasets tag the codebase version as a revision: try v2.0 first, then v2.1.
    try:
        metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.0")
    except Exception:
        try:
            metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.1")
        except Exception as e:
            print(f"Error loading metadata for {repo_id}: {str(e)}")
            return None

    # Check version
    version_str = str(metadata._version).strip()
    if version_str not in ["2.0", "2.1"]:
        print(f"Skipping {repo_id}: version <{version_str}>")
        return None

    try:
        info = {
            "repo_id": repo_id,
            "username": repo_id.split('/')[0],
            "robot_type": metadata.robot_type,
            "total_episodes": metadata.total_episodes,
            "total_frames": metadata.total_frames,
            "fps": metadata.fps,
            "camera_keys": ','.join(metadata.camera_keys),  # Convert list to string
            "num_cameras": len(metadata.camera_keys),
            "video_keys": ','.join(metadata.video_keys),
            "has_video": len(metadata.video_keys) > 0,
            "total_tasks": metadata.total_tasks,
            "tasks": json.dumps(metadata.tasks),  # Convert dict to JSON string
            "is_sim": "sim_" in repo_id.lower(),
            "is_eval": "eval_" in repo_id.lower(),
            "features": ','.join(metadata.features.keys()),
            "chunks_size": metadata.chunks_size,
            "total_chunks": metadata.total_chunks,
            "version": metadata._version,
        }
        return info
    except Exception as e:
        print(f"Error extracting metadata for {repo_id}: {str(e)}")
        return None


def extract_metadata_fn(tags, progress=gr.Progress()):
    progress(0)
    api = HfApi()
    tags = tags.split(",") if tags else None
    datasets = api.list_datasets(tags=tags)
    repo_ids = [dataset.id for dataset in datasets]
    gr.Info(f"Found {len(repo_ids)} datasets with the provided tags. Extracting metadata...")
    dataset_infos = []
    # progress.tqdm needs a sized iterable to report progress, so iterate the list directly.
    for repo_id in progress.tqdm(repo_ids):
        info = analyze_dataset_metadata(repo_id)
        if info is not None:
            dataset_infos.append(info)

    # Convert to DataFrame and save to CSV
    df = pd.DataFrame(dataset_infos)
    csv_filename = "lerobot_datasets.csv"
    gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
    df.to_csv(csv_filename, index=False)
    return df


def load_metadata_fn(file_explorer):
    gr.Info(f"Metadata loaded from {file_explorer}.")
    df = pd.read_csv(file_explorer)
    return df


def filter_tasks(tasks_json):
    """Filter out tasks that are too short or look like machine-generated names."""
    try:
        tasks = json.loads(tasks_json)
        valid_tasks = [task for task in tasks.values()
                       if task and isinstance(task, str) and len(task.strip()) > 10
                       and len(task.split("_")) < 3 and "test" not in task.lower()]
        return len(valid_tasks) > 0
    except Exception:
        return False


def filtering_metadata(
    df,
    num_episodes,
    num_frames,
    include_sim,
    robot_set,
    include_eval,
    filter_unlabeled_tasks,
):
    all_data_number = len(df)
    filtered_datasets = df[
        (df['total_episodes'] >= num_episodes) &
        (df['total_frames'] >= num_frames) &
        (df['has_video'] == True) &
        (df['is_sim'] == include_sim) &
        (df['robot_type'].isin(robot_set)) &
        # Element-wise check: `'test' not in df['repo_id']` would test the index, not each row.
        (~df['repo_id'].str.contains('test'))
    ].copy()
    if not include_eval:
        filtered_datasets = filtered_datasets[filtered_datasets['is_eval'] == False]
    if filter_unlabeled_tasks:
        filtered_datasets['has_valid_tasks'] = filtered_datasets['tasks'].apply(filter_tasks)
        filtered_datasets = filtered_datasets[filtered_datasets['has_valid_tasks']]
    gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
    return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets


class LeRobotAnalysisApp:
    def __init__(self, ui_obj):
        self.name = "LeRobot Analysis App"
        self.description = "Analyze LeRobot datasets"
        self.ui_obj = ui_obj

    # TODO
    def create_app(self):
        with self.ui_obj:
            gr.Markdown("Application to filter & analyze LeRobot datasets")
            filtered_data = gr.DataFrame(visible=False)
            with gr.Tabs():
                with gr.TabItem("1) Extract/Load Data"):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("# Extract metadata from the HF API")
                            gr.Markdown("Choose a set of **tags** (separated by a comma) to select the datasets to extract **metadata** from.")
                            gr.Markdown("The final metadata will be saved to a **CSV file**.")
                            tags = gr.Textbox(label="Tags", value="LeRobot",
                                              placeholder="Enter tags separated by commas",
                                              info="Enter tags separated by commas",
                                              lines=3)
                            btn_extract = gr.Button("Extract Data")
                            gr.Markdown("# OR load from CSV")
                            gr.Markdown("If you already downloaded the metadata as a CSV, you can load it directly here.")
                            file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
                            btn_load = gr.Button("Load CSV Data")
                        with gr.Column():
                            out_data = gr.DataFrame()
                    btn_extract.click(extract_metadata_fn, [tags], [out_data])
                    btn_load.click(load_metadata_fn, [file_explorer], [out_data])
                with gr.TabItem("2) Filter Data"):
                    @gr.render(inputs=[out_data])
                    def filter_data(out_data):
                        if out_data.empty:
                            gr.Markdown("# Filtering data")
                            gr.Markdown("No data to display: please extract or load metadata first")
                        else:
                            df = out_data
                            min_eps = int(df['total_episodes'].min())
                            min_frames = int(df['total_frames'].min())
                            robot_types = sorted(set(df['robot_type'].to_list()))
                            with gr.Row():
                                with gr.Column():
                                    gr.Markdown("# Filtering data")
                                    gr.Markdown("Filter the extracted datasets to your needs")
                                    data = gr.DataFrame(label="Dataset Metadata", value=out_data)
                                    is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
                                    eps = gr.Number(label="Min episodes", value=min_eps)
                                    frames = gr.Number(label="Min frames", value=min_frames)
                                    robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
                                    incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
                                    filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
                                    btn_filter = gr.Button("Filter Data")
                                with gr.Column():
                                    out_num_d = gr.Number(label="Number of datasets", value=0)
                                    out_text = gr.Text(label="Dataset repo IDs", value="")
                            btn_filter.click(filtering_metadata,
                                             inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task],
                                             outputs=[out_num_d, out_text, filtered_data])
                with gr.TabItem("3) Analyze Data"):
                    @gr.render(inputs=[out_data, filtered_data])
                    def analyze_data(out_data, filtered_data):
                        if out_data.empty:
                            gr.Markdown("# Analyzing data")
                            gr.Markdown("No data to display: please extract or load metadata first")
                        else:
                            with gr.Row():
                                with gr.Column():
                                    if filtered_data.empty:
                                        gr.BarPlot(out_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                    else:
                                        actions_df = self.extract_actions_from_tasks(filtered_data['tasks'])
                                        gr.BarPlot(filtered_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                        gr.BarPlot(actions_df, title="Count of each action",
                                                   x="actions",
                                                   y="count",
                                                   x_label="Actions",
                                                   y_label="Count of actions")

    def extract_actions_from_tasks(self, tasks):
        gr.Info("Extracting actions from tasks, it might take a while...")
        nlp = spacy.load("en_core_web_sm")
        actions = []
        for el in tasks:
            dict_tasks = ast.literal_eval(el)
            for task in dict_tasks.values():
                doc = nlp(task)
                for token in doc:
                    if token.pos_ == "VERB":
                        actions.append(token.lemma_)
        # Count occurrences of each unique action verb
        actions_unique = list(set(actions))
        count_actions = [actions.count(action) for action in actions_unique]

        return pd.DataFrame({"actions": actions_unique, "count": count_actions})

    def launch_ui(self):
        self.ui_obj.launch()


if __name__ == "__main__":
    app = gr.Blocks()
    ui = LeRobotAnalysisApp(app)
    ui.create_app()
    ui.launch_ui()
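A note on setup, inferred from the imports rather than pinned by this commit: the Space needs gradio, pandas, huggingface_hub, spacy, and lerobot installed, plus spaCy's en_core_web_sm model (e.g. `python -m spacy download en_core_web_sm`). The metadata helper can also be exercised outside the UI; a minimal sketch, using the public lerobot/pusht dataset as an example repo id:

# Inspect one dataset's metadata without launching the app (network access required).
from app import analyze_dataset_metadata

info = analyze_dataset_metadata("lerobot/pusht")
if info is not None:  # None means the metadata could not be loaded or was skipped
    print(info["robot_type"], info["total_episodes"], info["fps"])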
lerobot_datasets.csv ADDED
The diff for this file is too large to render; see the raw diff.
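Since the committed lerobot_datasets.csv was written by extract_metadata_fn above, it can also be consumed directly with pandas; a minimal sketch, assuming the column set built in analyze_dataset_metadata (the thresholds are arbitrary example values):

import json
import pandas as pd

df = pd.read_csv("lerobot_datasets.csv")
# Same spirit as filtering_metadata, without the Gradio UI:
subset = df[(df["total_episodes"] >= 10) & df["has_video"] & ~df["is_eval"]]
# "tasks" was serialized with json.dumps, so decode it before use
# (assumes at least one row passed the filters above):
first_tasks = json.loads(subset.iloc[0]["tasks"])
print(len(subset), list(first_tasks.values())[:3])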