File size: 2,139 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import einops
import pickle
from pathlib import Path
from typing import Optional
from datasets.core import TrajectoryDataset


class PushTDataset(TrajectoryDataset):
    """Push-T trajectory dataset backed by pre-saved tensors on disk.

    Expected files under ``data_directory``:
      * ``states.pth`` -- per-episode states (first dim indexes episodes).
      * ``abs_actions.pth`` / ``rel_actions.pth`` -- padded per-episode
        actions, indexed as ``actions[episode, t, :]``.
      * ``seq_lengths.pkl`` -- pickled sequence of valid lengths per episode.
      * ``obses/episode_XXX.pth`` -- per-episode observation video (THWC).
    """

    def __init__(
        self,
        data_directory,
        subset_fraction: Optional[float] = None,
        relative=False,
    ):
        """Load states, actions and sequence lengths from ``data_directory``.

        Args:
            data_directory: Directory containing the dataset files.
            subset_fraction: If given, keep only the first
                ``int(num_episodes * subset_fraction)`` episodes. Must lie
                in ``(0, 1]``.
            relative: If True load ``rel_actions.pth``, otherwise
                ``abs_actions.pth``.

        Raises:
            ValueError: If ``subset_fraction`` is given but not in ``(0, 1]``.
        """
        self.data_directory = Path(data_directory)
        self.relative = relative
        self.states = torch.load(self.data_directory / "states.pth")
        actions_file = "rel_actions.pth" if relative else "abs_actions.pth"
        self.actions = torch.load(self.data_directory / actions_file)
        # NOTE: pickle can execute arbitrary code; only load trusted data.
        with open(self.data_directory / "seq_lengths.pkl", "rb") as f:
            self.seq_lengths = pickle.load(f)

        self.subset_fraction = subset_fraction
        # Compare against None explicitly: a falsy 0.0 is an invalid fraction
        # and must be rejected, not silently treated as "use everything".
        if self.subset_fraction is not None:
            if not 0 < self.subset_fraction <= 1:
                raise ValueError(
                    f"subset_fraction must be in (0, 1], got {self.subset_fraction}"
                )
            n = int(len(self.states) * self.subset_fraction)
        else:
            n = len(self.states)
        self.states = self.states[:n]
        self.actions = self.actions[:n]
        self.seq_lengths = self.seq_lengths[:n]

        # Re-zero everything past each episode's valid length. The stored
        # padding is presumably not guaranteed to be zero -- TODO confirm.
        for i, T in enumerate(self.seq_lengths):
            self.actions[i, T:] = 0

    def get_seq_length(self, idx):
        """Return the number of valid timesteps in episode ``idx``."""
        return self.seq_lengths[idx]

    def get_all_actions(self):
        """Concatenate the valid (unpadded) actions of every episode.

        Returns:
            Tensor of shape ``(sum(seq_lengths), action_dim)``; an empty
            ``(0, action_dim)`` tensor when the dataset has no episodes.
        """
        valid = [self.actions[i, :T, :] for i, T in enumerate(self.seq_lengths)]
        if not valid:
            # torch.cat rejects an empty list; return a correctly-shaped empty.
            return self.actions.new_zeros((0, *self.actions.shape[2:]))
        return torch.cat(valid, dim=0)

    def get_frames(self, idx, frames):
        """Load observations and actions for selected timesteps of an episode.

        Args:
            idx: Episode index; selects ``obses/episode_{idx:03d}.pth``.
            frames: Time index (e.g. a ``range`` or list) applied to both
                the observation video and the action sequence.

        Returns:
            ``(obs, act, mask)`` where ``obs`` is ``(T, 1, C, H, W)`` scaled
            by ``1/255`` (single camera view), ``act`` is
            ``self.actions[idx, frames]`` and ``mask`` is an all-True bool
            tensor of length ``len(act)``.
        """
        episode_path = self.data_directory / "obses" / f"episode_{idx:03d}.pth"
        obs = torch.load(str(episode_path))
        obs = obs[frames]  # THWC
        obs = einops.rearrange(obs, "T H W C -> T 1 C H W") / 255.0  # T V C H W, 1 view
        act = self.actions[idx, frames]
        mask = torch.ones(len(act), dtype=torch.bool)
        return obs, act, mask

    def __getitem__(self, idx):
        """Return ``get_frames`` over every valid timestep of episode ``idx``."""
        return self.get_frames(idx, range(self.get_seq_length(idx)))

    def __len__(self):
        """Return the number of episodes in the (possibly subset) dataset."""
        return len(self.seq_lengths)