File size: 6,833 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import torch
import einops
import numpy as np
import torch.nn as nn
from .base import AbstractSSL
from accelerate import Accelerator
from typing import Tuple, Dict, Optional
from ..transformer_encoder import TransformerEncoder, TransformerEncoderConfig
from ..ema import EMA


# https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py#L239
def off_diag(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries are produced in row-major order, skipping the diagonal.
    """
    rows, cols = x.shape
    assert rows == cols
    # Boolean mask that is False on the diagonal, True elsewhere; indexing
    # with it selects the off-diagonal entries in row-major order.
    keep = ~torch.eye(rows, dtype=torch.bool, device=x.device)
    return x[keep]


def off_diag_cov_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared off-diagonal covariance over the trailing feature dim.

    Flattens every leading dimension of ``x`` into a sample axis, computes
    the E x E covariance matrix of the E-dimensional features, and returns
    the mean of the squared off-diagonal entries (a VICReg-style
    decorrelation penalty).

    Args:
        x: tensor of shape ``(..., E)``.

    Returns:
        Scalar tensor with the penalty value.
    """
    # torch.cov expects variables as rows and observations as columns, so
    # flatten "... E" into (num_samples, E) and transpose. This replaces the
    # previous einops.rearrange("... E -> E (...)") with equivalent native
    # ops, dropping a third-party dependency for a trivial reshape.
    feats = x.reshape(-1, x.shape[-1]).t()
    cov = torch.cov(feats)
    # Select the off-diagonal entries with a boolean mask and average their
    # squares (inlined equivalent of the off_diag helper).
    mask = ~torch.eye(cov.shape[0], dtype=torch.bool, device=cov.device)
    return cov[mask].square().mean()


# Module-level HF Accelerate handle, created at import time; DynaMoSSL.__init__
# uses it to wrap the forward-dynamics model and its optimizer for
# distributed / mixed-precision training.
accelerator = Accelerator()


class DynaMoSSL(AbstractSSL):
    """Self-supervised objective built around a latent forward-dynamics model.

    A transformer is trained to predict next-step encoder features of one view
    from the current features of a (possibly different) view concatenated with
    a projection of the next observation. The prediction is scored with a
    cosine-similarity loss against a detached target — the encoder output
    itself when ``ema_beta is None`` (SimSiam-style), otherwise an EMA copy of
    the encoder — and an off-diagonal covariance penalty (VICReg-style)
    discourages feature collapse.

    NOTE(review): the slicing in ``_forward_dyn_loss_one_pair`` and the use of
    ``obs_proj.shape[2]`` suggest observation feature tensors are laid out as
    (batch, time, views, feature) — confirm against the caller.
    """

    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
        window_size: int,
        feature_dim: int,
        projection_dim: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float = 0.0,
        covariance_reg_coef: float = 0.04,
        dynamics_loss_coef: float = 1.0,
        ema_beta: Optional[float] = None,  # None for SimSiam; float for EMA encoder
        beta_scheduling: bool = False,
        projector_use_ema: bool = False,
        lr: float = 1e-4,
        weight_decay: float = 0.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        separate_single_views: bool = True,
    ) -> None:
        nn.Module.__init__(self)
        # avoid registering encoder/projector as submodules: storing through
        # __dict__ keeps them out of this module's parameters()/state_dict()
        # (presumably they are owned and optimized elsewhere — verify against
        # the training loop)
        self.__dict__["encoder"] = encoder
        self.__dict__["projector"] = projector
        # Dynamics model consumes [features(t), projection(t+1)] concatenated
        # on the last dim and predicts features(t+1), hence the input/output
        # dims below.
        forward_dynamics_cfg = TransformerEncoderConfig(
            block_size=window_size,
            input_dim=feature_dim + projection_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            output_dim=feature_dim,
        )
        self.forward_dynamics = TransformerEncoder(forward_dynamics_cfg)
        self.forward_dynamics_optimizer = self.forward_dynamics.configure_optimizers(
            weight_decay=weight_decay,
            lr=lr,
            betas=betas,
        )
        # Wrap model + optimizer together so Accelerate can handle device
        # placement / DDP / mixed precision; must happen after the optimizer
        # is constructed.
        self.forward_dynamics, self.forward_dynamics_optimizer = accelerator.prepare(
            self.forward_dynamics,
            self.forward_dynamics_optimizer,
        )
        self.covariance_reg_coef = covariance_reg_coef
        self.dynamics_loss_coef = dynamics_loss_coef
        self.ema_beta = ema_beta
        self.beta_scheduling = beta_scheduling
        self.projector_use_ema = projector_use_ema
        # EMA targets only exist in the EMA variant; in the SimSiam variant
        # the live encoder output doubles as the (detached) target.
        if self.ema_beta is not None:
            self.ema_encoder = EMA(self.encoder, self.ema_beta)
            if self.projector_use_ema:
                self.ema_projector = EMA(self.projector, self.ema_beta)
        self.separate_single_views = separate_single_views

    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Dict[str, torch.Tensor],
    ]:
        """Encode observations and compute the total SSL loss.

        Returns a tuple of (encoded obs, projected obs, total loss, dict of
        per-term loss components for logging).
        """
        obs_enc = self.encoder(obs)
        if self.ema_beta is not None:
            obs_target = self.ema_encoder(obs)  # use EMA encoder as target
            if self.projector_use_ema:
                obs_proj = self.ema_projector(obs_enc)
            else:
                obs_proj = self.projector(obs_enc)
        else:
            obs_target = obs_enc  # use SimSiam target
            obs_proj = self.projector(obs_enc)

        # Covariance regularizer acts on the raw encoder features; the
        # dynamics loss handles the (detached) prediction targets itself.
        covariance_loss = self._covariance_reg_loss(obs_enc)
        dynamics_loss, dynamics_loss_components = self._forward_dyn_loss(
            obs_enc, obs_proj, obs_target, self.separate_single_views
        )
        total_loss = dynamics_loss + covariance_loss
        loss_components = {
            "total_loss": total_loss,
            **dynamics_loss_components,
            "covariance_loss": covariance_loss,
        }
        return obs_enc, obs_proj, total_loss, loss_components

    def _forward_dyn_loss(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        separate_single_views: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Aggregate the forward-dynamics loss over view pairs.

        With ``separate_single_views`` each view predicts itself (pairs
        (i, i)); otherwise every ordered cross-view pair (i, j), i != j, is
        used. Either way the per-pair losses are scaled so their coefficients
        sum to ``dynamics_loss_coef``.
        """
        V = obs_proj.shape[2]  # number of views
        total = torch.zeros(1, device=obs_enc.device)
        loss_components = {}
        if separate_single_views:
            for i in range(V):
                loss = self._forward_dyn_loss_one_pair(
                    obs_enc, obs_proj, obs_target, i, i
                )
                loss *= self.dynamics_loss_coef / V
                total += loss
                loss_components[f"dynamics_loss_{i}_{i}"] = loss
        else:
            total_view_pairs = V * (V - 1)  # w/ order
            for i in range(V):
                for j in range(V):
                    if i == j:
                        continue
                    loss = self._forward_dyn_loss_one_pair(
                        obs_enc, obs_proj, obs_target, i, j
                    )
                    loss *= self.dynamics_loss_coef / total_view_pairs
                    total += loss
                    loss_components[f"dynamics_loss_{i}_{j}"] = loss
        loss_components["dynamics_loss_total"] = total
        # Log the current EMA decay so beta scheduling is visible in metrics.
        if self.ema_beta is not None:
            loss_components["ema_beta"] = torch.Tensor([self.ema_encoder.beta]).to(
                obs_enc.device
            )
        return total, loss_components

    def _forward_dyn_loss_one_pair(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        i: int,
        j: int,
    ) -> torch.Tensor:
        """Cosine prediction loss for one (projection view i -> target view j) pair.

        The dynamics model sees view j's features at step t alongside view i's
        projection at step t+1, and must match view j's (detached) target
        features at step t+1.
        """
        forward_dyn_input = torch.cat([obs_enc[:, :-1, j], obs_proj[:, 1:, i]], dim=-1)
        obs_enc_pred = self.forward_dynamics(forward_dyn_input)  # (N, T-1, E)
        # 1 - mean cosine similarity; target is detached (stop-gradient) so
        # only the prediction path trains the encoder.
        loss = (
            1
            - torch.nn.functional.cosine_similarity(
                obs_enc_pred, obs_target[:, 1:, j].detach(), dim=-1
            ).mean()
        )
        return loss

    def _covariance_reg_loss(self, obs_enc: torch.Tensor) -> torch.Tensor:
        """Scaled off-diagonal covariance penalty on the encoder features."""
        loss = off_diag_cov_loss(obs_enc)
        return loss * self.covariance_reg_coef

    def adjust_beta(self, epoch: int, max_epoch: int) -> None:
        """Cosine-anneal the EMA decay from ``ema_beta`` (epoch 0) toward 1.0.

        No-op unless an EMA encoder exists and ``beta_scheduling`` is on;
        ``max_epoch == 0`` is skipped to avoid division by zero.
        """
        if (self.ema_beta is None) or not self.beta_scheduling or (max_epoch == 0):
            return
        self.ema_encoder.beta = 1.0 - 0.5 * (
            1.0 + np.cos(np.pi * epoch / max_epoch)
        ) * (1.0 - self.ema_beta)
        if self.projector_use_ema:
            self.ema_projector.beta = 1.0 - 0.5 * (
                1.0 + np.cos(np.pi * epoch / max_epoch)
            ) * (1.0 - self.ema_beta)

    def step(self) -> None:
        """Step the dynamics optimizer and advance the EMA targets.

        The encoder/projector have no optimizer here — presumably they are
        stepped by an external training loop.
        """
        self.forward_dynamics_optimizer.step()
        self.forward_dynamics_optimizer.zero_grad(set_to_none=True)
        if self.ema_beta is not None:
            self.ema_encoder.step(self.encoder)
            if self.projector_use_ema:
                self.ema_projector.step(self.projector)