dynamo_ssl / datasets /block_pushing.py
jeffacce
initial commit
393d3de
import os
import torch
import einops
import numpy as np
from pathlib import Path
from typing import Optional
from datasets.core import TrajectoryDataset
class PushMultiviewTrajectoryDataset(TrajectoryDataset):
def __init__(
self,
data_directory: os.PathLike,
onehot_goals=False,
subset_fraction: Optional[float] = None,
prefetch: bool = False,
):
self.data_directory = Path(data_directory)
self.states = np.load(self.data_directory / "multimodal_push_observations.npy")
self.actions = np.load(self.data_directory / "multimodal_push_actions.npy")
self.masks = np.load(self.data_directory / "multimodal_push_masks.npy")
self.subset_fraction = subset_fraction
if self.subset_fraction:
assert self.subset_fraction > 0 and self.subset_fraction <= 1
n = int(len(self.states) * self.subset_fraction)
else:
n = len(self.states)
self.states = self.states[:n]
self.actions = self.actions[:n]
self.masks = self.masks[:n]
self.states = torch.from_numpy(self.states).float()
self.actions = torch.from_numpy(self.actions).float() / 0.03
self.masks = torch.from_numpy(self.masks).bool()
self.prefetch = prefetch
if self.prefetch:
self.obses = []
for i in range(n):
vid_path = self.data_directory / "obs_multiview" / f"{i:03d}.pth"
self.obses.append(torch.load(vid_path))
self.onehot_goals = onehot_goals
if self.onehot_goals:
self.goals = torch.load(self.data_directory / "onehot_goals.pth").float()
self.goals = self.goals[:n]
def get_seq_length(self, idx):
return int(self.masks[idx].sum().item())
def get_all_actions(self):
result = []
# mask out invalid actions
for i in range(len(self.masks)):
T = int(self.masks[i].sum().item())
result.append(self.actions[i, :T, :])
return torch.cat(result, dim=0)
def get_frames(self, idx, frames):
if self.prefetch:
obs = self.obses[idx][frames]
else:
obs = torch.load(self.data_directory / "obs_multiview" / f"{idx:03d}.pth")[
frames
]
obs = einops.rearrange(obs, "T V H W C -> T V C H W") / 255.0
act = self.actions[idx, frames]
mask = self.masks[idx, frames]
if self.onehot_goals:
goal = self.goals[idx, frames]
return obs, act, mask, goal
else:
return obs, act, mask
def __getitem__(self, idx):
T = self.masks[idx].sum().int().item()
return self.get_frames(idx, range(T))
def __len__(self):
return len(self.states)