|
import utils |
|
import torch |
|
import numpy as np |
|
from pathlib import Path |
|
from datasets.core import TrajectoryDataset |
|
|
|
|
|
class SimKitchenTrajectoryDataset(TrajectoryDataset):
    """Trajectory dataset read from a simulated-kitchen data directory.

    Files read from ``data_directory``:

    * ``observations_seq.npy`` and ``actions_seq.npy``: state / action arrays.
    * ``onehot_goals.pth``: goals, loaded with ``torch.load``.
    * ``existence_mask.npy``: summed over axis 0 to obtain the length of
      every trajectory.
    * ``obses/NNN.pth``: one observation tensor per trajectory, named by the
      zero-padded three-digit trajectory index.
    """

    def __init__(self, data_directory, prefetch=True, onehot_goals=False):
        """Load states, actions, goals and trajectory lengths from disk.

        Args:
            data_directory: Directory containing the files listed in the
                class docstring (anything accepted by ``pathlib.Path``).
            prefetch: If true, load every ``obses/NNN.pth`` file up front and
                keep it in memory; otherwise ``get_frames`` reloads the file
                of the requested trajectory on every call.
            onehot_goals: If true, ``get_frames`` (and therefore indexing)
                also returns the goals of the selected frames.
        """
        self.data_directory = Path(data_directory)

        # NOTE(review): torch.load relies on pickle, so data_directory should
        # only ever point at trusted files.
        states = torch.from_numpy(np.load(self.data_directory / "observations_seq.npy"))
        actions = torch.from_numpy(np.load(self.data_directory / "actions_seq.npy"))
        goals = torch.load(self.data_directory / "onehot_goals.pth")

        # Presumably swaps the leading time and trajectory axes so that the
        # first axis indexes trajectories, which is how the tensors are
        # indexed below -- confirm against utils.transpose_batch_timestep.
        self.states, self.actions, self.goals = utils.transpose_batch_timestep(
            states, actions, goals
        )

        # Summing the mask over axis 0 yields one length per trajectory.
        self.Ts = (
            np.load(self.data_directory / "existence_mask.npy")
            .sum(axis=0)
            .astype(int)
            .tolist()
        )

        self.prefetch = prefetch
        if self.prefetch:
            obs_dir = self.data_directory / "obses"
            self.obses = [
                torch.load(obs_dir / f"{i:03d}.pth") for i in range(len(self.Ts))
            ]

        # Read by get_frames regardless of the prefetch setting.
        self.onehot_goals = onehot_goals

    def get_seq_length(self, idx):
        """Return the length of trajectory ``idx``."""
        return self.Ts[idx]

    def get_all_actions(self):
        """Return the actions of all trajectories concatenated along dim 0.

        Each trajectory contributes only its first ``Ts[i]`` steps, so the
        steps beyond a trajectory's length are left out.
        """
        valid_actions = [
            self.actions[i, :length, :] for i, length in enumerate(self.Ts)
        ]
        return torch.cat(valid_actions, dim=0)

    def get_frames(self, idx, frames):
        """Return the selected frames of trajectory ``idx``.

        Args:
            idx: Trajectory index.
            frames: Sized index into the time axis (``__getitem__`` passes a
                ``range``); it must support ``len``.

        Returns:
            ``(obs, act, mask)``, or ``(obs, act, mask, goal)`` when the
            dataset was built with ``onehot_goals=True``. ``obs`` is divided
            by 255 and ``mask`` is a 1-D tensor of ones with one entry per
            requested frame.
        """
        if self.prefetch:
            obs = self.obses[idx][frames]
        else:
            # Not cached: reload this trajectory's observation file.
            obs = torch.load(self.data_directory / "obses" / f"{idx:03d}.pth")[frames]

        # Presumably rescales 8-bit image values to [0, 1] -- confirm the
        # dtype of the stored observations.
        obs = obs / 255.0
        act = self.actions[idx, frames]
        mask = torch.ones(len(frames))

        if self.onehot_goals:
            goal = self.goals[idx, frames]
            return obs, act, mask, goal
        return obs, act, mask

    def __getitem__(self, idx):
        """Return every frame of trajectory ``idx`` up to its length."""
        return self.get_frames(idx, range(self.Ts[idx]))

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.Ts)