|
import abc |
|
import utils |
|
import torch |
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
from typing import Optional, Callable |
|
|
|
|
|
class TrajectoryDataset(Dataset, abc.ABC):
    """
    Abstract base class for datasets whose items are whole trajectories.

    Indexing contract -- ``TrajectoryDataset[i]`` yields the tuple
    ``(observations, actions, mask)`` where:

        observations: Tensor[T, ...] -- T frames of observations
        actions:      Tensor[T, ...] -- T frames of actions
        mask:         Tensor[T]      -- per-frame validity flag
                                        (0: invalid; 1: valid)

    Concrete subclasses must implement ``get_seq_length``.
    """

    @abc.abstractmethod
    def get_seq_length(self, idx):
        """
        Return the number of frames in the trajectory at index ``idx``.

        Abstract: subclasses must override this method; the base
        implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
|
|
|
|
|
class TrajectorySlicerDataset(TrajectoryDataset):
    """
    View over a ``TrajectoryDataset`` that exposes fixed-size windows.

    At construction time every trajectory of the wrapped dataset is cut into
    ``(trajectory_index, start, end)`` slices following the half-open
    ``[start, end)`` convention. Two kinds of slices are produced for each
    trajectory that is long enough:

    * "prefix" slices ``(i, 0, 1) ... (i, 0, window - 1)`` that are shorter
      than ``window`` and are padded back to full length in ``__getitem__``;
    * full slices ``(i, start, start + window)``.

    Trajectories shorter than the minimum required length are skipped and
    reported via ``print``.
    """

    def __init__(
        self,
        dataset: TrajectoryDataset,
        window: int,
        action_window: int,
        vqbet_get_future_action_chunk: bool = True,
        future_conditional: bool = False,
        min_future_sep: int = 0,
        future_seq_len: Optional[int] = None,
        only_sample_tail: bool = False,
        transform: Optional[Callable] = None,
        use_libero_goal: bool = False,
    ):
        """
        Build the slice index for ``dataset``.

        Args:
            dataset: Wrapped trajectory dataset; must provide ``__len__``,
                ``__getitem__`` and ``get_seq_length``.
            window: Number of observation frames per item.
            action_window: Number of extra action frames; each item carries
                up to ``window + action_window - 1`` action frames.
            vqbet_get_future_action_chunk: If True a trajectory needs at
                least ``window + action_window`` frames, otherwise
                ``max(window, action_window)``.
            future_conditional: If True, ``future_seq_len`` is mandatory and
                ``get_seq_length`` reports ``future_seq_len + window``.
            min_future_sep: Stored on the instance; not used by the code in
                this class.
            future_seq_len: Length added to ``window`` by ``get_seq_length``
                when ``future_conditional`` is set.
            only_sample_tail: Stored on the instance; not used by the code in
                this class.
            transform: Optional callable applied to the list of values
                assembled by ``__getitem__``.
            use_libero_goal: If True, element ``[2]`` of each wrapped sample
                is sliced like the observations and appended to the item.

        Raises:
            AssertionError: If ``future_conditional`` is True but
                ``future_seq_len`` is None.
        """
        if future_conditional and future_seq_len is None:
            # Raised explicitly (instead of via `assert`) so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError("must specify a future_seq_len")

        self.dataset = dataset
        self.window = window
        self.action_window = action_window
        self.vqbet_get_future_action_chunk = vqbet_get_future_action_chunk
        self.future_conditional = future_conditional
        self.min_future_sep = min_future_sep
        self.future_seq_len = future_seq_len
        self.only_sample_tail = only_sample_tail
        self.transform = transform
        self.use_libero_goal = use_libero_goal
        self.slices = []

        if vqbet_get_future_action_chunk:
            min_window_required = window + action_window
        else:
            min_window_required = max(window, action_window)

        # Shortest trajectory seen so far; only used for the summary message.
        min_seq_length = np.inf
        for i in range(len(self.dataset)):
            T = self.dataset.get_seq_length(i)
            min_seq_length = min(T, min_seq_length)
            if T < min_window_required:
                print(
                    f"Ignored short sequence #{i}: len={T}, window={min_window_required}"
                )
                continue
            # Prefix slices of length 1 .. window - 1, all starting at frame
            # 0; __getitem__ pads them back up to `window` frames.
            self.slices.extend((i, 0, end + 1) for end in range(window - 1))
            # Full-length slices, indices follow the [start, end) convention.
            self.slices.extend(
                (i, start, start + window)
                for start in range(T - min_window_required)
            )

        if min_seq_length < min_window_required:
            print(
                f"Ignored short sequences. To include all, set window <= {min_seq_length}."
            )

    def get_seq_length(self, idx: int) -> int:
        """
        Return the nominal length of an item; it is the same for every ``idx``.

        Returns:
            ``future_seq_len + window`` when ``future_conditional`` is set,
            otherwise ``window``.
        """
        if self.future_conditional:
            return self.future_seq_len + self.window
        return self.window

    def __len__(self):
        """Return the total number of slices over all kept trajectories."""
        return len(self.slices)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th slice as a tuple.

        The tuple holds the observation slice, the action slice, the goal
        slice when ``use_libero_goal`` is set, and -- if only two values are
        present after the optional ``transform`` -- a trailing tensor of ones
        with shape ``(1, 1, 1)``.
        """
        i, start, end = self.slices[idx]

        # Index the wrapped dataset exactly once: avoids repeating its
        # loading work and guarantees obs / actions / goals all come from the
        # same sample even if the wrapped dataset is not deterministic.
        sample = self.dataset[i]

        # Prefix slices are shorter than the window and must be padded.
        needs_padding = end - start < self.window

        obs = sample[0][start:end]
        act = sample[1][start : end - 1 + self.action_window]
        if needs_padding:
            # NOTE(review): presumably repeats the first frame along `dim`
            # until the requested length is reached -- confirm against
            # utils.inference.repeat_start_to_length.
            obs = utils.inference.repeat_start_to_length(
                obs, self.window, dim=0
            )
            act = utils.inference.repeat_start_to_length(
                act,
                self.window + self.action_window - 1,
                dim=0,
            )
        values = [obs, act]

        if self.use_libero_goal:
            goals = sample[2][start:end]
            if needs_padding:
                goals = utils.inference.repeat_start_to_length(
                    goals, self.window, dim=0
                )
            values.append(goals)

        if self.transform is not None:
            # Normalise to a list so a transform returning a tuple does not
            # break the `.append` below.
            values = list(self.transform(values))
        if len(values) == 2:
            # Placeholder third element so every item has at least 3 entries.
            values.append(torch.ones([1, 1, 1]))
        return tuple(values)
|
|