# dynamo_ssl/datasets/vqbet_repro.py
import abc
import utils
import torch
import numpy as np
from torch.utils.data import Dataset
from typing import Optional, Callable


class TrajectoryDataset(Dataset, abc.ABC):
"""
A dataset containing trajectories.
TrajectoryDataset[i] returns: (observations, actions, mask)
observations: Tensor[T, ...], T frames of observations
actions: Tensor[T, ...], T frames of actions
mask: Tensor[T]: 0: invalid; 1: valid
"""
@abc.abstractmethod
def get_seq_length(self, idx):
"""
Returns the length of the idx-th trajectory.
"""
raise NotImplementedError
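

# Illustrative sketch (hypothetical, not part of the repo): a minimal in-memory
# TrajectoryDataset, assuming `observations`/`actions` are tensors of shape
# [N, T, ...] padded to a common length T and `masks` is [N, T] with 1 = valid.
class TensorTrajectoryDataset(TrajectoryDataset):
    def __init__(self, observations, actions, masks):
        self.observations = observations
        self.actions = actions
        self.masks = masks

    def __len__(self):
        return len(self.observations)

    def __getitem__(self, idx):
        # full (padded) trajectory; TrajectorySlicerDataset slices it further
        return self.observations[idx], self.actions[idx], self.masks[idx]

    def get_seq_length(self, idx):
        # number of valid frames in trajectory idx
        return int(self.masks[idx].sum().item())

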
class TrajectorySlicerDataset(TrajectoryDataset):
def __init__(
self,
dataset: TrajectoryDataset,
window: int,
action_window: int,
vqbet_get_future_action_chunk: bool = True,
future_conditional: bool = False,
min_future_sep: int = 0,
future_seq_len: Optional[int] = None,
only_sample_tail: bool = False,
transform: Optional[Callable] = None,
use_libero_goal: bool = False,
):
if future_conditional:
assert future_seq_len is not None, "must specify a future_seq_len"
self.dataset = dataset
self.window = window
self.action_window = action_window
self.vqbet_get_future_action_chunk = vqbet_get_future_action_chunk
self.future_conditional = future_conditional
self.min_future_sep = min_future_sep
self.future_seq_len = future_seq_len
self.only_sample_tail = only_sample_tail
self.transform = transform
self.slices = []
self.use_libero_goal = use_libero_goal
min_seq_length = np.inf
        # Minimum trajectory length needed to cut a full slice: with future
        # action chunking, `window` obs frames plus `action_window` future
        # actions; otherwise the longer of the two.
        if vqbet_get_future_action_chunk:
            min_window_required = window + action_window
        else:
            min_window_required = max(window, action_window)
for i in range(len(self.dataset)): # type: ignore
T = self.dataset.get_seq_length(i) # avoid reading actual seq (slow)
min_seq_length = min(T, min_seq_length)
            if T - min_window_required < 0:
                print(
                    f"Ignored short sequence #{i}: len={T}, required={min_window_required}"
                )
            else:
                # short head slices (length < window); __getitem__ pads these by
                # repeating the start frame up to the full window length
                self.slices += [
                    (i, 0, end + 1) for end in range(window - 1)
                ]  # slice indices follow convention [start, end)
                # full slices of exactly `window` frames
                self.slices += [
                    (i, start, start + window)
                    for start in range(T - min_window_required)
                ]  # slice indices follow convention [start, end)
        if min_seq_length < min_window_required:
            print(
                "Ignored short sequences. To include all, reduce window/action_window "
                f"so the required length is <= {min_seq_length}."
            )

    def get_seq_length(self, idx: int) -> int:
        # length of the slices produced by this dataset, not of the underlying trajectory
if self.future_conditional:
return self.future_seq_len + self.window
        else:
            return self.window

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
i, start, end = self.slices[idx]
        if end - start < self.window:
            # short head slice: pad obs and actions by repeating the start frame
            obs, act, *others = self.dataset[i]
obs = utils.inference.repeat_start_to_length(
obs[start:end], self.window, dim=0
)
act = utils.inference.repeat_start_to_length(
act[start : end - 1 + self.action_window],
self.window + self.action_window - 1,
dim=0,
)
values = [obs, act]
        else:
            # full slice: `window` obs frames and the matching action chunk
            values = [
self.dataset[i][0][start:end],
self.dataset[i][1][start : end - 1 + self.action_window],
]
        if self.use_libero_goal:
            # LIBERO goals: slice per-frame goals, padding short slices like observations
goals = self.dataset[i][2][start:end]
if end - start < self.window:
goals = utils.inference.repeat_start_to_length(
goals, self.window, dim=0
)
values.append(goals)
# optionally apply transform
if self.transform is not None:
values = self.transform(values)
if len(values) == 2: # placeholder goal
values.append(torch.ones([1, 1, 1]))
return tuple(values)
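

# Usage sketch (hypothetical): slicing the toy dataset above into fixed windows.
# Running this requires the repo's `utils` package (imported at the top) to be
# importable; the shapes below are illustrative assumptions.
if __name__ == "__main__":
    obs = torch.randn(4, 50, 10)  # 4 trajectories, 50 frames, 10-dim observations
    act = torch.randn(4, 50, 2)   # matching 2-dim actions
    mask = torch.ones(4, 50)      # every frame valid
    sliced = TrajectorySlicerDataset(
        TensorTrajectoryDataset(obs, act, mask),
        window=5,
        action_window=3,
    )
    o, a, goal = sliced[10]  # a full-length slice
    # -> obs window [5, 10], action chunk [7, 2], placeholder goal [1, 1, 1]
    print(o.shape, a.shape, goal.shape)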