|
import os |
|
import torch |
|
import einops |
|
import numpy as np |
|
from pathlib import Path |
|
from typing import Optional |
|
from datasets.core import TrajectoryDataset |
|
|
|
|
|
class PushMultiviewTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset read from a directory of multi-view push data.

    The directory is expected to contain:

    * ``multimodal_push_observations.npy`` -- per-trajectory states,
    * ``multimodal_push_actions.npy`` -- per-trajectory actions,
    * ``multimodal_push_masks.npy`` -- per-trajectory validity masks,
    * ``obs_multiview/NNN.pth`` -- one observation file per trajectory,
      named by the zero-padded (3 digit) trajectory index,
    * ``onehot_goals.pth`` -- only read when ``onehot_goals`` is true.

    The number of valid steps of a trajectory is the sum of its mask.
    """

    def __init__(
        self,
        data_directory: os.PathLike,
        onehot_goals: bool = False,
        subset_fraction: Optional[float] = None,
        prefetch: bool = False,
    ):
        """Load states, actions and masks, optionally keeping only a prefix.

        Args:
            data_directory: Directory holding the files listed on the class.
            onehot_goals: When true, also load ``onehot_goals.pth`` and make
                ``get_frames`` return a goal as a fourth element.
            subset_fraction: Fraction in ``(0, 1]`` of the leading
                trajectories to keep; ``None`` keeps all of them.
            prefetch: When true, load every kept trajectory's observation
                file into memory up front instead of on each access.

        Raises:
            ValueError: If ``subset_fraction`` is given but not in ``(0, 1]``.
        """
        self.data_directory = Path(data_directory)
        states = np.load(self.data_directory / "multimodal_push_observations.npy")
        actions = np.load(self.data_directory / "multimodal_push_actions.npy")
        masks = np.load(self.data_directory / "multimodal_push_masks.npy")

        self.subset_fraction = subset_fraction
        # Compare against None explicitly: a plain truthiness test would let
        # 0.0 skip validation and silently fall back to the full dataset.
        if subset_fraction is not None:
            # Raise instead of assert so the check survives `python -O`.
            if not 0 < subset_fraction <= 1:
                raise ValueError(
                    f"subset_fraction must be in (0, 1], got {subset_fraction!r}"
                )
            n = int(len(states) * subset_fraction)
        else:
            n = len(states)

        self.states = torch.from_numpy(states[:n]).float()
        # Actions are rescaled by the fixed constant 0.03.
        # NOTE(review): presumably an action normalisation factor -- confirm.
        self.actions = torch.from_numpy(actions[:n]).float() / 0.03
        self.masks = torch.from_numpy(masks[:n]).bool()

        self.prefetch = prefetch
        if self.prefetch:
            # NOTE(review): torch.load unpickles file contents; only point
            # data_directory at trusted data.
            self.obses = [torch.load(self._obs_path(i)) for i in range(n)]

        self.onehot_goals = onehot_goals
        if self.onehot_goals:
            goals = torch.load(self.data_directory / "onehot_goals.pth").float()
            self.goals = goals[:n]

    def _obs_path(self, idx):
        """Return the path of the observation file for trajectory ``idx``."""
        return self.data_directory / "obs_multiview" / f"{idx:03d}.pth"

    def get_seq_length(self, idx):
        """Return the number of valid steps (mask sum) of trajectory ``idx``."""
        return int(self.masks[idx].sum().item())

    def get_all_actions(self):
        """Return the valid actions of all trajectories concatenated on dim 0.

        For each trajectory only the first ``get_seq_length(i)`` actions are
        kept.
        """
        valid_actions = [
            self.actions[i, : self.get_seq_length(i), :]
            for i in range(len(self.masks))
        ]
        return torch.cat(valid_actions, dim=0)

    def get_frames(self, idx, frames):
        """Return the selected frames of trajectory ``idx``.

        Args:
            idx: Trajectory index.
            frames: Frame indices used to index the observation, action,
                mask (and goal) data of the trajectory.

        Returns:
            ``(obs, act, mask)``, or ``(obs, act, mask, goal)`` when the
            dataset was built with ``onehot_goals``. ``obs`` is rearranged
            from ``T V H W C`` to ``T V C H W`` and divided by 255.0.
        """
        if self.prefetch:
            obs = self.obses[idx][frames]
        else:
            # Without prefetching the trajectory file is read on every call.
            obs = torch.load(self._obs_path(idx))[frames]
        obs = einops.rearrange(obs, "T V H W C -> T V C H W") / 255.0
        act = self.actions[idx, frames]
        mask = self.masks[idx, frames]
        if self.onehot_goals:
            goal = self.goals[idx, frames]
            return obs, act, mask, goal
        return obs, act, mask

    def __getitem__(self, idx):
        """Return all valid frames of trajectory ``idx`` (see ``get_frames``)."""
        return self.get_frames(idx, range(self.get_seq_length(idx)))

    def __len__(self):
        """Return the number of trajectories kept in the dataset."""
        return len(self.states)
|
|