File size: 6,833 Bytes
393d3de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import torch
import einops
import numpy as np
import torch.nn as nn
from .base import AbstractSSL
from accelerate import Accelerator
from typing import Tuple, Dict, Optional
from ..transformer_encoder import TransformerEncoder, TransformerEncoderConfig
from ..ema import EMA
# https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py#L239
def off_diag(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Entries come back in row-major order with the diagonal skipped,
    matching the VICReg reference implementation linked above.
    """
    n, m = x.shape
    assert n == m
    # Boolean mask that drops the main diagonal; advanced indexing with a
    # mask yields the kept elements in row-major order.
    keep = ~torch.eye(n, dtype=torch.bool, device=x.device)
    return x[keep]
def off_diag_cov_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared off-diagonal entry of the feature covariance matrix.

    ``x`` is treated as ``(..., E)``: every leading dimension is collapsed
    into a sample axis and the covariance of the E feature channels is taken
    across those samples. Only cross-channel (off-diagonal) terms contribute,
    penalizing redundancy between feature dimensions.
    """
    dim = x.shape[-1]
    # Equivalent to einops.rearrange(x, "... E -> E (...)"): feature
    # channels as rows, samples as columns — the layout torch.cov expects.
    samples_by_feature = x.reshape(-1, dim).T
    cov = torch.cov(samples_by_feature)
    off_diagonal = cov[~torch.eye(dim, dtype=torch.bool, device=cov.device)]
    return off_diagonal.square().mean()
accelerator = Accelerator()
class DynaMoSSL(AbstractSSL):
    """Self-supervised learner pairing a latent forward-dynamics prediction
    loss with an off-diagonal covariance regularizer (DynaMo-style).

    The encoder and projector are owned by the caller and are deliberately
    NOT registered as submodules (stored via ``self.__dict__``); only the
    forward-dynamics transformer and its optimizer are created and trained
    here. Prediction targets come from an EMA copy of the encoder when
    ``ema_beta`` is set, otherwise from the SimSiam stop-gradient trick.
    """

    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
        window_size: int,
        feature_dim: int,
        projection_dim: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float = 0.0,
        covariance_reg_coef: float = 0.04,
        dynamics_loss_coef: float = 1.0,
        ema_beta: Optional[float] = None, # None for SimSiam; float for EMA encoder
        beta_scheduling: bool = False,
        projector_use_ema: bool = False,
        lr: float = 1e-4,
        weight_decay: float = 0.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        separate_single_views: bool = True,
    ):
        nn.Module.__init__(self)
        # avoid registering encoder/projector as submodules
        self.__dict__["encoder"] = encoder
        self.__dict__["projector"] = projector
        # The dynamics model consumes [features ; projection] per time step
        # and is trained to predict the next step's features.
        forward_dynamics_cfg = TransformerEncoderConfig(
            block_size=window_size,
            input_dim=feature_dim + projection_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            output_dim=feature_dim,
        )
        self.forward_dynamics = TransformerEncoder(forward_dynamics_cfg)
        self.forward_dynamics_optimizer = self.forward_dynamics.configure_optimizers(
            weight_decay=weight_decay,
            lr=lr,
            betas=betas,
        )
        # Hand model + optimizer to the module-level Accelerator for device
        # placement / distributed wrapping.
        self.forward_dynamics, self.forward_dynamics_optimizer = accelerator.prepare(
            self.forward_dynamics,
            self.forward_dynamics_optimizer,
        )
        self.covariance_reg_coef = covariance_reg_coef
        self.dynamics_loss_coef = dynamics_loss_coef
        self.ema_beta = ema_beta
        self.beta_scheduling = beta_scheduling
        self.projector_use_ema = projector_use_ema
        if self.ema_beta is not None:
            # BYOL-style EMA target encoder; optionally an EMA projector too.
            self.ema_encoder = EMA(self.encoder, self.ema_beta)
            if self.projector_use_ema:
                self.ema_projector = EMA(self.projector, self.ema_beta)
        self.separate_single_views = separate_single_views

    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Dict[str, torch.Tensor],
    ]:
        """Encode ``obs`` and compute the total SSL loss.

        Returns ``(obs_enc, obs_proj, total_loss, loss_components)`` where
        ``total_loss = dynamics_loss + covariance_loss``.
        NOTE(review): the dynamics loss indexes encodings as
        (batch, time, view, embed) — confirm ``obs`` is laid out so the
        encoder/projector produce that shape.
        """
        obs_enc = self.encoder(obs)
        if self.ema_beta is not None:
            obs_target = self.ema_encoder(obs) # use EMA encoder as target
            if self.projector_use_ema:
                obs_proj = self.ema_projector(obs_enc)
            else:
                obs_proj = self.projector(obs_enc)
        else:
            obs_target = obs_enc # use SimSiam target
            obs_proj = self.projector(obs_enc)
        covariance_loss = self._covariance_reg_loss(obs_enc)
        dynamics_loss, dynamics_loss_components = self._forward_dyn_loss(
            obs_enc, obs_proj, obs_target, self.separate_single_views
        )
        total_loss = dynamics_loss + covariance_loss
        loss_components = {
            "total_loss": total_loss,
            **dynamics_loss_components,
            "covariance_loss": covariance_loss,
        }
        return obs_enc, obs_proj, total_loss, loss_components

    def _forward_dyn_loss(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        separate_single_views: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Accumulate the per-view-pair dynamics loss.

        With ``separate_single_views`` each view predicts only itself (pairs
        (i, i)); otherwise every ordered cross-view pair (i, j), i != j, is
        used. Each term is divided by the number of terms so the overall
        weight stays ``dynamics_loss_coef`` regardless of view count.
        Returns the summed loss and a dict of per-pair components.
        """
        V = obs_proj.shape[2] # number of views
        total = torch.zeros(1, device=obs_enc.device)
        loss_components = {}
        if separate_single_views:
            for i in range(V):
                loss = self._forward_dyn_loss_one_pair(
                    obs_enc, obs_proj, obs_target, i, i
                )
                loss *= self.dynamics_loss_coef / V
                total += loss
                loss_components[f"dynamics_loss_{i}_{i}"] = loss
        else:
            total_view_pairs = V * (V - 1) # w/ order
            for i in range(V):
                for j in range(V):
                    if i == j:
                        continue
                    loss = self._forward_dyn_loss_one_pair(
                        obs_enc, obs_proj, obs_target, i, j
                    )
                    loss *= self.dynamics_loss_coef / total_view_pairs
                    total += loss
                    loss_components[f"dynamics_loss_{i}_{j}"] = loss
        loss_components["dynamics_loss_total"] = total
        if self.ema_beta is not None:
            # Log the current EMA decay — it can change under beta scheduling.
            loss_components["ema_beta"] = torch.Tensor([self.ema_encoder.beta]).to(
                obs_enc.device
            )
        return total, loss_components

    def _forward_dyn_loss_one_pair(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        i: int,
        j: int,
    ) -> torch.Tensor:
        """Negative-cosine-similarity loss for predicting view ``j``'s next
        features from view ``j``'s past features plus view ``i``'s next
        projection.

        The target is detached: in SimSiam mode this is the stop-gradient;
        with an EMA encoder it keeps the target branch out of the graph.
        """
        forward_dyn_input = torch.cat([obs_enc[:, :-1, j], obs_proj[:, 1:, i]], dim=-1)
        obs_enc_pred = self.forward_dynamics(forward_dyn_input) # (N, T-1, E)
        loss = (
            1
            - torch.nn.functional.cosine_similarity(
                obs_enc_pred, obs_target[:, 1:, j].detach(), dim=-1
            ).mean()
        )
        return loss

    def _covariance_reg_loss(self, obs_enc: torch.Tensor) -> torch.Tensor:
        # VICReg-style covariance penalty on the encoder output, scaled by
        # the configured coefficient.
        loss = off_diag_cov_loss(obs_enc)
        return loss * self.covariance_reg_coef

    def adjust_beta(self, epoch: int, max_epoch: int) -> None:
        """Cosine-anneal the EMA decay from ``ema_beta`` (epoch 0) toward 1.0
        (epoch == max_epoch). No-op without EMA targets, without
        ``beta_scheduling``, or when ``max_epoch`` is 0 (avoids div-by-zero).
        """
        if (self.ema_beta is None) or not self.beta_scheduling or (max_epoch == 0):
            return
        self.ema_encoder.beta = 1.0 - 0.5 * (
            1.0 + np.cos(np.pi * epoch / max_epoch)
        ) * (1.0 - self.ema_beta)
        if self.projector_use_ema:
            self.ema_projector.beta = 1.0 - 0.5 * (
                1.0 + np.cos(np.pi * epoch / max_epoch)
            ) * (1.0 - self.ema_beta)

    def step(self) -> None:
        """Apply the dynamics optimizer step, clear grads, and advance EMAs."""
        self.forward_dynamics_optimizer.step()
        self.forward_dynamics_optimizer.zero_grad(set_to_none=True)
        if self.ema_beta is not None:
            self.ema_encoder.step(self.encoder)
            if self.projector_use_ema:
                self.ema_projector.step(self.projector)
|