import abc
from typing import Any, Callable, List, Optional, Sequence

import numpy as np
import torch
from torch import default_generator, randperm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, Subset

import utils
# Taken from the Python 3.5 docs (itertools.accumulate recipe)
def _accumulate(iterable, fn=lambda x, y: x + y):
"Return running totals"
# _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = fn(total, element)
yield total
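# Illustrative example (not used at runtime): pairing each running total with
# its length recovers contiguous index blocks, which is how random_split_traj
# below carves a shuffled index list into splits.
#   list(_accumulate([3, 7]))               # -> [3, 10]
#   list(zip(_accumulate([3, 7]), [3, 7]))  # -> [(3, 3), (10, 7)]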
class TrajectoryDataset(Dataset, abc.ABC):
"""
A dataset containing trajectories.
    TrajectoryDataset[i] returns: (observations, actions, mask)
        observations: Tensor[T, ...], T frames of observations
        actions: Tensor[T, ...], T frames of actions
        mask: Tensor[T], bool; True marks valid frames, False marks invalid ones
    """
@abc.abstractmethod
def get_seq_length(self, idx):
"""
Returns the length of the idx-th trajectory.
"""
raise NotImplementedError
    @abc.abstractmethod
    def get_frames(self, idx, frames):
        """
        Returns the specified frames of the idx-th trajectory.
        Used to speed up slicing.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_all_actions(self):
        """
        Returns all actions in the dataset, concatenated along dim 0.
        Subclasses in this file delegate to or implement this.
        """
        raise NotImplementedError
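# A minimal in-memory reference implementation (a sketch for illustration;
# nothing else in this file depends on it). The names `TensorTrajectoryDataset`,
# `obs_list`, and `act_list` are ours, not part of the original API.
class TensorTrajectoryDataset(TrajectoryDataset):
    def __init__(self, obs_list: List[torch.Tensor], act_list: List[torch.Tensor]):
        # obs_list[i]: Tensor[T_i, ...]; act_list[i]: Tensor[T_i, ...]
        assert len(obs_list) == len(act_list)
        self.obs_list = obs_list
        self.act_list = act_list

    def get_seq_length(self, idx):
        return len(self.obs_list[idx])

    def get_all_actions(self):
        return torch.cat(self.act_list, dim=0)

    def get_frames(self, idx, frames):
        frames = list(frames)
        # every stored frame is real data, so the mask is all True
        mask = torch.ones(len(frames), dtype=torch.bool)
        return self.obs_list[idx][frames], self.act_list[idx][frames], mask

    def __getitem__(self, idx):
        return self.get_frames(idx, range(self.get_seq_length(idx)))

    def __len__(self):
        return len(self.obs_list)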
class TrajectorySubset(TrajectoryDataset, Subset):
"""
Subset of a trajectory dataset at specified indices.
Args:
dataset (TrajectoryDataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
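    Example:
        TrajectorySubset(ds, [0, 2]) exposes trajectories 0 and 2 of `ds` as
        local indices 0 and 1, so subset.get_seq_length(1) == ds.get_seq_length(2).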
"""
def __init__(self, dataset: TrajectoryDataset, indices: Sequence[int]):
Subset.__init__(self, dataset, indices)
def get_seq_length(self, idx):
return self.dataset.get_seq_length(self.indices[idx])
def get_all_actions(self):
return self.dataset.get_all_actions()
def get_frames(self, idx, frames):
return self.dataset.get_frames(self.indices[idx], frames)
class TrajectorySlicerDataset(Dataset):
def __init__(
self,
dataset: TrajectoryDataset,
window: int,
future_conditional: bool = False,
min_future_sep: int = 0,
future_seq_len: Optional[int] = None,
only_sample_tail: bool = False,
transform: Optional[Callable] = None,
num_extra_predicted_actions: Optional[int] = None,
frame_step: int = 1,
repeat_first_frame: bool = False,
):
"""
Slice a trajectory dataset into unique (but overlapping) sequences of length `window`.
dataset: a trajectory dataset that satisfies:
dataset.get_seq_length(i) is implemented to return the length of sequence i
dataset[i] = (observations, actions, mask)
observations: Tensor[T, ...]
actions: Tensor[T, ...]
mask: Tensor[T]
False: invalid
True: valid
window: int
number of timesteps to include in each slice
future_conditional: bool = False
if True, observations will be augmented with future observations sampled from the same trajectory
min_future_sep: int = 0
minimum number of timesteps between the end of the current sequence and the start of the future sequence
for the future conditional
future_seq_len: Optional[int] = None
the length of the future conditional sequence;
required if future_conditional is True
only_sample_tail: bool = False
if True, only sample future sequences from the tail of the trajectory
transform: function (observations, actions, mask[, goal]) -> (observations, actions, mask[, goal])
"""
if future_conditional:
assert future_seq_len is not None, "must specify a future_seq_len"
self.dataset = dataset
self.window = window
self.future_conditional = future_conditional
self.min_future_sep = min_future_sep
self.future_seq_len = future_seq_len
self.only_sample_tail = only_sample_tail
self.transform = transform
self.num_extra_predicted_actions = num_extra_predicted_actions or 0
self.slices = []
self.frame_step = frame_step
min_seq_length = np.inf
        if num_extra_predicted_actions:
            # each slice must also cover the extra actions to be predicted
            window = window + num_extra_predicted_actions
        # number of raw frames a slice spans, given the frame stride
        window_len_with_step = (window - 1) * frame_step + 1
        for i in range(len(self.dataset)):  # type: ignore
            T = self.dataset.get_seq_length(i)  # avoid reading actual seq (slow)
            min_seq_length = min(T, min_seq_length)
            if T < window_len_with_step:
                print(f"Ignored short sequence #{i}: len={T}, need >= {window_len_with_step}")
                continue
            if repeat_first_frame:
                # partial slices anchored at frame 0; padded to full length
                # in __getitem__ by repeating the first frame
                self.slices += [(i, 0, end + 1) for end in range(window - 1)]
            last_start = T - window_len_with_step
            # slice indices follow convention [start, end); range(last_start + 1)
            # includes the final full window, which range(last_start) dropped
            self.slices += [
                (i, start, start + window_len_with_step)
                for start in range(last_start + 1)
            ]
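        # worked example: with T=10, window=4, frame_step=2 a slice spans
        # (4 - 1) * 2 + 1 = 7 raw frames, so valid starts are 0..3 and the
        # slice (i, 0, 7) reads frames {0, 2, 4, 6} in __getitem__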
        if min_seq_length < window_len_with_step:
            print(
                f"Ignored short sequences. To include all, use a window spanning at most {min_seq_length} frames."
            )
def get_seq_length(self, idx: int) -> int:
if self.future_conditional:
return self.future_seq_len + self.window
else:
return self.window
def get_all_actions(self) -> torch.Tensor:
return self.dataset.get_all_actions()
def __len__(self):
return len(self.slices)
def __getitem__(self, idx):
i, start, end = self.slices[idx]
T = self.dataset.get_seq_length(i)
        # self.num_extra_predicted_actions is coerced to an int in __init__,
        # so it is never None here
        if self.num_extra_predicted_actions:
            assert self.frame_step == 1, "num_extra_predicted_actions is untested with frame_step != 1"
if self.future_conditional:
raise NotImplementedError(
"num_extra_predicted_actions with future_conditional not implemented"
)
assert end <= T, f"end={end} > T={T}"
observations, actions, mask = self.dataset.get_frames(i, range(start, end))
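            # observations are trimmed back to `window` frames; actions keep
            # the extra `num_extra_predicted_actions` frames to be predicted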
observations = observations[: self.window]
values = [observations, actions, mask.bool()]
else:
if self.future_conditional:
                assert self.frame_step == 1, "future_conditional is untested with frame_step != 1"
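                # the future goal must start at least `min_future_sep` steps
                # after this slice ends and still fit `future_seq_len` frames
                # before the trajectory ends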
                valid_start_range = (
                    end + self.min_future_sep,
                    T - self.future_seq_len,
                )
if valid_start_range[0] < valid_start_range[1]:
if self.only_sample_tail:
future_obs_range = range(T - self.future_seq_len, T)
else:
future_start = np.random.randint(*valid_start_range)
future_end = future_start + self.future_seq_len
future_obs_range = range(future_start, future_end)
obs, actions, mask = self.dataset.get_frames(
i, list(range(start, end)) + list(future_obs_range)
)
future_obs = obs[end - start :]
obs = obs[: end - start]
actions = actions[: end - start]
mask = mask[: end - start]
else:
# zeros placeholder T x obs_dim
obs, actions, mask = self.dataset.get_frames(i, range(start, end))
obs_dims = obs.shape[1:]
future_obs = torch.zeros((self.future_seq_len, *obs_dims))
# [observations, actions, mask, future_obs (goal conditional)]
values = [obs, actions, mask.bool(), future_obs]
else:
observations, actions, mask = self.dataset.get_frames(
i, range(start, end, self.frame_step)
)
values = [observations, actions, mask.bool()]
if end - start < self.window + self.num_extra_predicted_actions:
# this only happens for repeating the very first frames
values = [
utils.inference.repeat_start_to_length(
x, self.window + self.num_extra_predicted_actions, dim=0
)
for x in values
]
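        # returned observations never include the extra predicted-action frames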
values[0] = values[0][: self.window]
# optionally apply transform
if self.transform is not None:
values = self.transform(values)
return tuple(values)
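# Usage sketch (hedged; builds on the illustrative TensorTrajectoryDataset above):
#   traj_ds = TensorTrajectoryDataset(
#       obs_list=[torch.randn(20, 8), torch.randn(15, 8)],
#       act_list=[torch.randn(20, 2), torch.randn(15, 2)],
#   )
#   sliced = TrajectorySlicerDataset(traj_ds, window=5)
#   obs, act, mask = sliced[0]  # each covers 5 frames along dim 0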
class TrajectoryEmbeddingDataset(TrajectoryDataset):
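    """
    Wraps a TrajectoryDataset, embedding each trajectory once up front with
    `model` (via utils.inference.embed_trajectory_dataset) and caching the
    padded results on `device` for fast random access.
    """
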
def __init__(
self,
model,
dataset: TrajectoryDataset,
device="cpu",
embed_goal=False,
):
self.data = utils.inference.embed_trajectory_dataset(
model,
dataset,
obs_only=False,
device=device,
embed_goal=embed_goal,
)
assert len(self.data) == len(dataset)
self.seq_lengths = [len(x[0]) for x in self.data]
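        # pad each tensor stream (observations, actions, mask, ...) to a common
        # length and stack into [N, T_max, ...] tensors kept on `device`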
self.on_device_data = []
n_tensors = len(self.data[0])
for i in range(n_tensors):
self.on_device_data.append(
pad_sequence([x[i] for x in self.data], batch_first=True).to(device)
)
self.data = self.on_device_data
def get_seq_length(self, idx):
return self.seq_lengths[idx]
    def get_all_actions(self):
        # after __init__, self.data[1] is the padded [N, T_max, ...] action
        # tensor (the per-trajectory tuples were replaced above), so iterate
        # trajectories and keep only the valid (unpadded) steps of each
        return torch.cat(
            [self.data[1][i, : self.seq_lengths[i]] for i in range(len(self))],
            dim=0,
        )
    def get_frames(self, idx, frames):
        # index the padded [N, T_max, ...] tensors: trajectory idx, rows `frames`;
        # return a tuple to match the (observations, actions, mask) contract
        return tuple(x[idx, frames] for x in self.data)
def __getitem__(self, idx):
return self.get_frames(idx, range(self.get_seq_length(idx)))
def __len__(self):
return len(self.seq_lengths)
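# Usage sketch (hedged): `model` is whatever encoder
# utils.inference.embed_trajectory_dataset accepts (project-specific, not
# defined in this file); the tuple gains a goal tensor when embed_goal=True.
#   emb_ds = TrajectoryEmbeddingDataset(model, traj_ds, device="cuda")
#   obs_emb, act, mask = emb_ds[0]  # obs_emb: Tensor[T, embed_dim]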
def get_train_val_sliced(
traj_dataset: TrajectoryDataset,
train_fraction: float = 0.9,
random_seed: int = 42,
window_size: int = 10,
future_conditional: bool = False,
min_future_sep: int = 0,
future_seq_len: Optional[int] = None,
only_sample_tail: bool = False,
transform: Optional[Callable[[Any], Any]] = None,
num_extra_predicted_actions: Optional[int] = None,
frame_step: int = 1,
):
train, val = split_traj_datasets(
traj_dataset,
train_fraction=train_fraction,
random_seed=random_seed,
)
traj_slicer_kwargs = {
"window": window_size,
"future_conditional": future_conditional,
"min_future_sep": min_future_sep,
"future_seq_len": future_seq_len,
"only_sample_tail": only_sample_tail,
"transform": transform,
"num_extra_predicted_actions": num_extra_predicted_actions,
"frame_step": frame_step,
}
train_slices = TrajectorySlicerDataset(train, **traj_slicer_kwargs)
val_slices = TrajectorySlicerDataset(val, **traj_slicer_kwargs)
return train_slices, val_slices
def random_split_traj(
dataset: TrajectoryDataset,
lengths: Sequence[int],
generator: Optional[torch.Generator] = default_generator,
) -> List[TrajectorySubset]:
"""
(Modified from torch.utils.data.dataset.random_split)
Randomly split a trajectory dataset into non-overlapping new datasets of given lengths.
Optionally fix the generator for reproducible results, e.g.:
    >>> random_split_traj(traj_dataset, [3, 7], generator=torch.Generator().manual_seed(42))
Args:
dataset (TrajectoryDataset): TrajectoryDataset to be split
lengths (sequence): lengths of splits to be produced
generator (Generator): Generator used for the random permutation.
"""
# Cannot verify that dataset is Sized
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
)
indices = randperm(sum(lengths), generator=generator).tolist()
return [
TrajectorySubset(dataset, indices[offset - length : offset])
for offset, length in zip(_accumulate(lengths), lengths)
]
def split_traj_datasets(dataset, train_fraction=0.95, random_seed=42):
dataset_length = len(dataset)
lengths = [
int(train_fraction * dataset_length),
dataset_length - int(train_fraction * dataset_length),
]
train_set, val_set = random_split_traj(
dataset, lengths, generator=torch.Generator().manual_seed(random_seed)
)
return train_set, val_set
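
# Hedged smoke test: exercises the split + slice pipeline end to end using the
# illustrative TensorTrajectoryDataset defined above. Run this file directly
# inside the repo (it imports the project-local `utils` module at the top).
if __name__ == "__main__":
    traj_ds = TensorTrajectoryDataset(
        obs_list=[torch.randn(12, 8) for _ in range(10)],
        act_list=[torch.randn(12, 2) for _ in range(10)],
    )
    train_slices, val_slices = get_train_val_sliced(
        traj_ds, train_fraction=0.8, window_size=4
    )
    obs, act, mask = train_slices[0]
    assert obs.shape == (4, 8) and act.shape == (4, 2) and mask.all()
    print(f"train slices: {len(train_slices)}, val slices: {len(val_slices)}")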