dynamo_ssl / datasets /your_dataset.py
jeffacce
initial commit
393d3de
raw
history blame contribute delete
613 Bytes
import utils
import torch
import numpy as np
from pathlib import Path
from torch.utils.data import TensorDataset
from datasets.core import TrajectoryDataset
class YourTrajectoryDataset(TensorDataset, TrajectoryDataset):
    """Template trajectory dataset to be filled in for a custom data source.

    The data-access methods below are stubs that raise
    ``NotImplementedError``; replace them with real loading logic.
    ``__getitem__`` is already wired to return a whole trajectory by asking
    ``get_frames`` for every frame index ``0 .. get_seq_length(idx) - 1``.
    """

    def __init__(self, data_directory):
        """Record the dataset root directory.

        Args:
            data_directory: Path-like location of the raw data. It is
                normalised to a ``pathlib.Path`` and kept on
                ``self.data_directory`` for the loading code to use.
        """
        # Keep the normalised path on the instance; binding it only to a
        # local (as before) threw it away as soon as __init__ returned.
        # NOTE(review): TensorDataset.__init__ is not called here, so
        # ``self.tensors`` is never set.
        self.data_directory = Path(data_directory)

    def __len__(self):
        """Return the number of trajectories in the dataset (stub)."""
        # The inherited TensorDataset.__len__ reads ``self.tensors``, which
        # this class never sets; fail with an explicit "implement me" instead
        # of an unrelated AttributeError.
        raise NotImplementedError

    def get_seq_length(self, idx):
        """Return the number of frames in trajectory ``idx`` (stub)."""
        raise NotImplementedError

    def get_frames(self, idx, frames):
        """Return the requested ``frames`` of trajectory ``idx`` (stub).

        Args:
            idx: Index of the trajectory.
            frames: Iterable of frame indices to load; ``__getitem__``
                passes ``range(get_seq_length(idx))``.

        The template's hint for the return value is
        ``obs / 255.0, actions, masks`` -- presumably observations scaled
        from 0-255 pixel values; TODO confirm against the consumers.
        """
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return every frame of trajectory ``idx`` via ``get_frames``."""
        seq_len = self.get_seq_length(idx)
        return self.get_frames(idx, range(seq_len))