# dynamo_ssl/datasets/pusht.py
# jeffacce
# initial commit
# 393d3de
import torch
import einops
import pickle
from pathlib import Path
from typing import Optional
from datasets.core import TrajectoryDataset
class PushTDataset(TrajectoryDataset):
    """Trajectory dataset backed by pre-saved Push-T tensors on disk.

    Expected layout of ``data_directory``:

    * ``states.pth`` -- per-episode states (first dim indexes episodes).
    * ``abs_actions.pth`` / ``rel_actions.pth`` -- padded per-episode actions,
      indexed as ``actions[episode, t, action_dim]``.
    * ``seq_lengths.pkl`` -- pickled sequence of per-episode lengths.
    * ``obses/episode_XXX.pth`` -- per-episode observation tensors in
      ``T H W C`` layout.
    """

    def __init__(
        self,
        data_directory,
        subset_fraction: Optional[float] = None,
        relative=False,
    ):
        """Load states, actions, and episode lengths from ``data_directory``.

        Args:
            data_directory: Directory containing the files listed in the class
                docstring.
            subset_fraction: If given, keep only the first
                ``int(num_episodes * subset_fraction)`` episodes. Must lie in
                ``(0, 1]``.
            relative: Load ``rel_actions.pth`` when True, otherwise
                ``abs_actions.pth``.

        Raises:
            ValueError: If ``subset_fraction`` is outside ``(0, 1]``, or if the
                loaded files disagree on the number of episodes.
        """
        self.data_directory = Path(data_directory)
        self.relative = relative
        # NOTE(security): torch.load / pickle.load can execute arbitrary code
        # embedded in the file; only point this at trusted data.
        self.states = torch.load(self.data_directory / "states.pth")
        actions_file = "rel_actions.pth" if relative else "abs_actions.pth"
        self.actions = torch.load(self.data_directory / actions_file)
        with open(self.data_directory / "seq_lengths.pkl", "rb") as f:
            self.seq_lengths = pickle.load(f)

        self.subset_fraction = subset_fraction
        if subset_fraction is not None:
            # Explicit check (not a truthiness test / assert) so that 0 is
            # rejected rather than silently meaning "use everything", and so
            # the check survives `python -O`.
            if not 0 < subset_fraction <= 1:
                raise ValueError(
                    f"subset_fraction must be in (0, 1], got {subset_fraction!r}"
                )
            n = int(len(self.states) * subset_fraction)
        else:
            n = len(self.states)

        self.states = self.states[:n]
        self.actions = self.actions[:n]
        self.seq_lengths = self.seq_lengths[:n]
        if not len(self.actions) == len(self.seq_lengths) == n:
            raise ValueError(
                "episode count mismatch: "
                f"states={n}, actions={len(self.actions)}, "
                f"seq_lengths={len(self.seq_lengths)}"
            )

        # Redo zero padding: everything past each episode's length is zeroed.
        for i, T in enumerate(self.seq_lengths):
            self.actions[i, T:] = 0

    def get_seq_length(self, idx):
        """Return the number of valid timesteps in episode ``idx``."""
        return self.seq_lengths[idx]

    def get_all_actions(self):
        """Concatenate the unpadded actions of every episode along time.

        Returns:
            Tensor of shape ``(sum(seq_lengths), action_dim)``. An empty
            dataset yields an empty ``(0, action_dim)`` tensor instead of
            letting ``torch.cat`` fail on an empty list.
        """
        per_episode = [
            self.actions[i, :T, :] for i, T in enumerate(self.seq_lengths)
        ]
        if not per_episode:
            return self.actions.new_zeros((0, self.actions.shape[-1]))
        return torch.cat(per_episode, dim=0)

    def get_frames(self, idx, frames):
        """Load observations and actions for selected timesteps of an episode.

        Args:
            idx: Episode index; selects ``obses/episode_{idx:03d}.pth``.
            frames: Timestep indices to select (e.g. a ``range``).

        Returns:
            ``(obs, act, mask)`` where ``obs`` is ``(T, 1, C, H, W)`` scaled by
            ``1/255`` (presumably uint8 pixels -- TODO confirm), ``act`` holds
            the actions at ``frames``, and ``mask`` is an all-True bool tensor
            of length ``len(act)``.
        """
        vid_dir = self.data_directory / "obses"
        # NOTE(security): torch.load unpickles; only use with trusted files.
        obs = torch.load(str(vid_dir / f"episode_{idx:03d}.pth"))
        obs = obs[frames]  # THWC
        # T V C H W with a single view.
        obs = einops.rearrange(obs, "T H W C -> T 1 C H W") / 255.0
        act = self.actions[idx, frames]
        mask = torch.ones(len(act), dtype=torch.bool)
        return obs, act, mask

    def __getitem__(self, idx):
        """Return ``get_frames`` over every valid timestep of episode ``idx``."""
        return self.get_frames(idx, range(self.get_seq_length(idx)))

    def __len__(self):
        """Return the number of episodes."""
        return len(self.seq_lengths)