Beegbrain committed on
Commit 827715e · 1 Parent(s): 27b97b7

Add app.py

Files changed (2)
  1. app.py +222 -0
  2. lerobot_datasets.csv +0 -0
app.py ADDED
@@ -0,0 +1,222 @@
+ import json
+
+ import gradio as gr
+ import pandas as pd
+ import spacy
+ from huggingface_hub import HfApi
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+
+ def analyze_dataset_metadata(repo_id: str):
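+     """Load the metadata of one dataset repo and flatten it into a dict of scalar fields; return None on failure."""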
+     # Datasets may be published under the v2.0 or v2.1 revision; try both.
+     try:
+         metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.0")
+     except Exception:
+         try:
+             metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.1")
+         except Exception as e:
+             print(f"Error loading metadata for {repo_id}: {str(e)}")
+             return None
+
+     # Check version
+     version_str = str(metadata._version).strip()
+     if version_str not in ["2.0", "2.1"]:
+         print(f"Skipping {repo_id}: version <{version_str}>")
+         return None
+
+     try:
+         info = {
+             "repo_id": repo_id,
+             "username": repo_id.split('/')[0],
+             "robot_type": metadata.robot_type,
+             "total_episodes": metadata.total_episodes,
+             "total_frames": metadata.total_frames,
+             "fps": metadata.fps,
+             "camera_keys": ','.join(metadata.camera_keys),  # Convert list to string
+             "num_cameras": len(metadata.camera_keys),
+             "video_keys": ','.join(metadata.video_keys),
+             "has_video": len(metadata.video_keys) > 0,
+             "total_tasks": metadata.total_tasks,
+             "tasks": json.dumps(metadata.tasks),  # Convert dict to JSON string
+             "is_sim": "sim_" in repo_id.lower(),
+             "is_eval": "eval_" in repo_id.lower(),
+             "features": ','.join(metadata.features.keys()),
+             "chunks_size": metadata.chunks_size,
+             "total_chunks": metadata.total_chunks,
+             "version": metadata._version,
+         }
+         return info
+     except Exception as e:
+         print(f"Error extracting metadata for {repo_id}: {str(e)}")
+         return None
+
+ def extract_metadata_fn(tags, progress=gr.Progress()):
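+     """List Hub datasets matching the given tags, collect each one's metadata, and save the table to CSV."""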
+     progress(0)
+     api = HfApi()
+     tags = [t.strip() for t in tags.split(",")] if tags else None
+     datasets = api.list_datasets(tags=tags)
+     repo_ids = [dataset.id for dataset in datasets]
+     gr.Info(f"Found {len(repo_ids)} datasets with the provided tags. Extracting metadata...")
+     dataset_infos = []
+     for i, repo_id in progress.tqdm(enumerate(repo_ids), total=len(repo_ids)):
+         progress(i / len(repo_ids))
+         info = analyze_dataset_metadata(repo_id)
+         if info is not None:
+             dataset_infos.append(info)
+
+     # Convert to DataFrame and save to CSV
+     df = pd.DataFrame(dataset_infos)
+     csv_filename = "lerobot_datasets.csv"
+     gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
+     df.to_csv(csv_filename, index=False)
+     return df
+
+ def load_metadata_fn(file_explorer):
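+     """Load a previously extracted metadata table from the CSV file picked in the file explorer."""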
+     df = pd.read_csv(file_explorer)
+     gr.Info(f"Metadata loaded from {file_explorer}.")
+     return df
+
+ def filter_tasks(tasks_json):
+     """Return True if at least one task label looks like a real instruction (long enough, no odd names)."""
+     try:
+         tasks = json.loads(tasks_json)
+         valid_tasks = [task for task in tasks.values()
+                        if task and isinstance(task, str) and len(task.strip()) > 10
+                        and len(task.split("_")) < 3 and "test" not in task.lower()]
+         return len(valid_tasks) > 0
+     except (TypeError, ValueError, AttributeError):
+         return False
+
+ def filtering_metadata(
+     df,
+     num_episodes,
+     num_frames,
+     include_sim,
+     robot_set,
+     include_eval,
+     filter_unlabeled_tasks,
+ ):
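+     """Apply the UI filter settings to the metadata table; return (count, repo ids, filtered DataFrame)."""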
+     all_data_number = len(df)
+     filtered_datasets = df[
+         (df['total_episodes'] >= num_episodes) &
+         (df['total_frames'] >= num_frames) &
+         (df['has_video']) &
+         (df['is_sim'] == include_sim) &
+         (df['robot_type'].isin(robot_set)) &
+         # str.contains checks the values; `'test' not in df['repo_id']` only checked the index
+         (~df['repo_id'].str.contains('test'))
+     ].copy()  # copy so the column added below is not written to a view
+     if not include_eval:
+         filtered_datasets = filtered_datasets[~filtered_datasets['is_eval']]
+     if filter_unlabeled_tasks:
+         filtered_datasets['has_valid_tasks'] = filtered_datasets['tasks'].apply(filter_tasks)
+         filtered_datasets = filtered_datasets[filtered_datasets['has_valid_tasks']]
+     gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
+     return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets
+
+ class LeRobotAnalysisApp:
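+     """Gradio Blocks app with three tabs: extract/load metadata, filter it, and plot simple statistics."""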
+     def __init__(self, ui_obj):
+         self.name = "LeRobot Analysis App"
+         self.description = "Analyze LeRobot datasets"
+         self.ui_obj = ui_obj
+
+     # TODO
+     def create_app(self):
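+         """Build the three-tab interface inside the Blocks object passed to the constructor."""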
+         with self.ui_obj:
+             gr.Markdown("Application to filter & analyze LeRobot datasets")
+             filtered_data = gr.DataFrame(visible=False)
+             with gr.Tabs():
+                 with gr.TabItem("1) Extract/Load Data"):
+                     with gr.Row():
+                         with gr.Column():
+                             gr.Markdown("# Extract metadata from HF API")
+                             gr.Markdown("Choose a set of **tags** (separated by a comma) to select the datasets to extract **metadata** from.")
+                             gr.Markdown("The final metadata will be saved to a **CSV file**.")
+                             tags = gr.Textbox(label="Tags", value="LeRobot",
+                                               placeholder="Enter tags separated by commas",
+                                               info="Enter tags separated by commas",
+                                               lines=3)
+                             btn_extract = gr.Button("Extract Data")
+                             gr.Markdown("# OR Load from CSV")
+                             gr.Markdown("If you have already downloaded the metadata as a CSV, you can load it directly here.")
+                             file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
+                             btn_load = gr.Button("Load CSV Data")
+                         with gr.Column():
+                             out_data = gr.DataFrame()
+                     btn_extract.click(extract_metadata_fn, [tags], [out_data])
+                     btn_load.click(load_metadata_fn, [file_explorer], [out_data])
+                 with gr.TabItem("2) Filter Data"):
+                     @gr.render(inputs=[out_data])
+                     def filter_data(out_data):
+                         if out_data.empty:
+                             gr.Markdown("# Filtering data")
+                             gr.Markdown("No data to display: please extract or load metadata first")
+                         else:
+                             df = out_data
+                             min_eps = int(df['total_episodes'].min())
+                             min_frames = int(df['total_frames'].min())
+                             robot_types = list(set(df['robot_type'].to_list()))
+                             robot_types.sort()
+                             with gr.Row():
+                                 with gr.Column():
+                                     gr.Markdown("# Filtering data")
+                                     gr.Markdown("Filter the extracted datasets to your needs")
+                                     data = gr.DataFrame(label="Dataset Metadata", value=out_data)
+                                     is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
+                                     eps = gr.Number(label="Min episodes", value=min_eps)
+                                     frames = gr.Number(label="Min frames", value=min_frames)
+                                     robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
+                                     incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
+                                     filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
+                                     btn_filter = gr.Button("Filter Data")
+                                 with gr.Column():
+                                     out_num_d = gr.Number(label="Number of datasets", value=0)
+                                     out_text = gr.Text(label="Dataset repo IDs", value="")
+                             btn_filter.click(filtering_metadata,
+                                              inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task],
+                                              outputs=[out_num_d, out_text, filtered_data])
+                 with gr.TabItem("3) Analyze Data"):
+                     @gr.render(inputs=[out_data, filtered_data])
+                     def analyze_data(out_data, filtered_data):
+                         if out_data.empty:
+                             gr.Markdown("# Analyzing data")
+                             gr.Markdown("No data to display: please extract or load metadata first")
+                         else:
+                             with gr.Row():
+                                 with gr.Column():
+                                     if filtered_data.empty:
+                                         gr.BarPlot(out_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
+                                     else:
+                                         actions_df = self.extract_actions_from_tasks(filtered_data['tasks'])
+                                         gr.BarPlot(filtered_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
+                                         gr.BarPlot(actions_df, title="Count of each action",
+                                                    x="actions",
+                                                    y="count",
+                                                    x_label="Actions",
+                                                    y_label="Count of actions")
+
+     def extract_actions_from_tasks(self, tasks):
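+         """Use spaCy POS tagging to pull the lemmatized verbs ("actions") out of every task description."""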
+         gr.Info("Extracting actions from tasks, it might take a while...")
+         nlp = spacy.load("en_core_web_sm")
+         actions = []
+         for el in tasks:
+             # The tasks column was stored with json.dumps, so decode it with json.loads
+             dict_tasks = json.loads(el)
+             for task in dict_tasks.values():
+                 doc = nlp(task)
+                 for token in doc:
+                     if token.pos_ == "VERB":
+                         actions.append(token.lemma_)
+         # Count how often each distinct verb appears
+         actions_unique = list(set(actions))
+         count_actions = [actions.count(action) for action in actions_unique]
+
+         return pd.DataFrame({"actions": actions_unique, "count": count_actions})
+
+     def launch_ui(self):
+         self.ui_obj.launch()
+
+
+ if __name__ == "__main__":
+     app = gr.Blocks()
+     ui = LeRobotAnalysisApp(app)
+     ui.create_app()
+     ui.launch_ui()
lerobot_datasets.csv ADDED
The diff for this file is too large to render. See raw diff
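For quick reference, a minimal sketch of calling the metadata helper above from a plain script, outside the Gradio UI (the repo id is only an illustration; any LeRobot v2.0/v2.1 dataset id works, and the lerobot package must be installed):

import app

# Example repo id; substitute any LeRobot-format dataset on the Hub
info = app.analyze_dataset_metadata("lerobot/aloha_static_coffee")
if info is not None:
    print(info["robot_type"], info["total_episodes"], info["num_cameras"])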