import torch
import einops
import numpy as np
from pathlib import Path
from typing import Optional
from torch.nn.utils.rnn import pad_sequence

from datasets.core import TrajectoryDataset


class LiberoGoalDataset(TrajectoryDataset):
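    """Trajectory dataset for the LIBERO-Goal benchmark suite.

    For every demo, low-dimensional states (joint positions, end-effector
    state, gripper joints, object states) and actions are loaded into padded
    tensors; image observations (agent-view and eye-in-hand cameras) are
    loaded from disk on demand. One goal image per task is taken from the
    final frame of that task's first demo.
    """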

    def __init__(self, data_directory, subset_fraction: Optional[float] = None):
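        """
        Args:
            data_directory: root directory containing a ``libero_goal`` folder
                with one subfolder per task, each holding per-demo folders.
            subset_fraction: if given, keep only the leading fraction of demos;
                must lie in (0, 1].
        """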
        self.dir = Path(data_directory) / "libero_goal"
        # One subdirectory per task, each containing one subdirectory per demo.
        self.task_names = sorted(self.dir.iterdir())
        self.demos = []
        for task_name in self.task_names:
            self.demos += list(task_name.iterdir())

        # Optionally keep only the leading fraction of demos.
        self.subset_fraction = subset_fraction
        if self.subset_fraction:
            assert 0 < self.subset_fraction <= 1
            n = int(len(self.demos) * self.subset_fraction)
            self.demos = self.demos[:n]

        # Load the low-dimensional arrays for every demo and concatenate them
        # into a single per-timestep state vector.
        self.joint_pos = []
        self.eef = []
        self.gripper_qpos = []
        self.object_states = []
        self.states = []
        self.actions = []
        for demo in self.demos:
            self.joint_pos.append(np.load(demo / "robot0_joint_pos.npy"))
            self.eef.append(np.load(demo / "robot0_eef.npy"))
            self.gripper_qpos.append(np.load(demo / "robot0_gripper_pos.npy"))
            self.object_states.append(np.load(demo / "object_states.npy"))
            state = np.concatenate(
                [
                    self.joint_pos[-1],
                    self.eef[-1],
                    self.gripper_qpos[-1],
                    self.object_states[-1],
                ],
                axis=1,
            )
            act = np.load(demo / "actions.npy")
            self.states.append(torch.from_numpy(state))
            self.actions.append(torch.from_numpy(act))

        # Zero-pad each per-timestep state vector out to MAX_DIM features so
        # all demos share a single state width.
        MAX_DIM = 128
        for i in range(len(self.states)):
            self.states[i] = torch.cat(
                [
                    self.states[i],
                    torch.zeros(
                        self.states[i].shape[0], MAX_DIM - self.states[i].shape[1]
                    ),
                ],
                dim=1,
            )

        # Pad all demos to the length of the longest one.
        self.states = pad_sequence(self.states, batch_first=True).float()
        self.actions = pad_sequence(self.actions, batch_first=True).float()

        # Build one goal image per task: the last frame of that task's first
        # demo. ``self.goals`` must stay None while calling get_frames here so
        # it returns (obs, act, None) instead of indexing into the goals list.
        self.goals = None
        goals = []
        # LIBERO-Goal stores 50 demos per task; step by 50 over however many
        # demos were loaded (the full suite has 10 tasks x 50 demos = 500).
        for i in range(0, len(self.demos), 50):
            last_obs, _, _ = self.get_frames(i, [-1])
            goals.append(last_obs)
        self.goals = goals

    def __len__(self):
        return len(self.demos)

    def get_frames(self, idx, frames):
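        """Return (obs, act, goal) for the selected frames of demo ``idx``.

        ``obs`` has shape (T, V, C, H, W) with V = 2 camera views (agent view,
        eye-in-hand), scaled to [0, 1]. ``goal`` is the per-task goal image
        repeated over the T frames, or None while goals are still being built.
        """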
        demo = self.demos[idx]
        agentview_obs = torch.load(str(demo / "agentview_image.pth"))
        robotview_obs = torch.load(str(demo / "robot0_eye_in_hand_image.pth"))
        agentview = agentview_obs[frames]
        robotview = robotview_obs[frames]
        # Stack the two camera views and convert to float images in [0, 1].
        obs = torch.stack([agentview, robotview], dim=1)
        obs = einops.rearrange(obs, "T V H W C -> T V C H W") / 255.0
        act = self.actions[idx][frames]

        if self.goals is not None:
            # Demos are stored 50 per task, so idx // 50 selects the task goal.
            task_idx = idx // 50
            goal = self.goals[task_idx].repeat(len(frames), 1, 1, 1, 1)
            return obs, act, goal
        else:
            return obs, act, None

    def __getitem__(self, idx):
        return self.get_frames(idx, range(len(self.joint_pos[idx])))

    def get_seq_length(self, idx):
        return len(self.joint_pos[idx])

    def get_all_actions(self):
        # Concatenate the unpadded action sequences of every demo into one tensor.
        actions = []
        for i in range(len(self.demos)):
            T = self.get_seq_length(i)
            actions.append(self.actions[i][:T])
        return torch.cat(actions, dim=0)
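
# Minimal usage sketch (assumes the LIBERO-Goal demos have already been
# preprocessed into the per-demo .npy/.pth layout this class expects; the
# data root below is illustrative):
#
#     dataset = LiberoGoalDataset("/path/to/data", subset_fraction=0.1)
#     obs, act, goal = dataset[0]                           # full demo 0
#     obs, act, goal = dataset.get_frames(0, [0, 10, -1])   # selected frames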