import torch
import einops
import numpy as np
import torch.nn as nn
from .base import AbstractSSL
from accelerate import Accelerator
from typing import Tuple, Dict, Optional
from ..transformer_encoder import TransformerEncoder, TransformerEncoderConfig
from ..ema import EMA


def off_diag(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor."""
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def off_diag_cov_loss(x: torch.Tensor) -> torch.Tensor:
    # Treat the last axis as embedding dimensions E and pool every other axis
    # (batch, time, views) as samples, then penalise the off-diagonal covariance.
    cov = torch.cov(einops.rearrange(x, "... E -> E (...)"))
    return off_diag(cov).square().mean()

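# Worked example (illustrative): for x = torch.arange(9.0).view(3, 3),
# off_diag(x) returns tensor([1., 2., 3., 5., 6., 7.]) -- the six off-diagonal
# entries -- so off_diag_cov_loss penalises the squared cross-dimension
# covariances while leaving the variances on the covariance diagonal untouched.
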
# Module-level Accelerator shared by DynaMoSSL below; it prepares the
# forward-dynamics model and its optimizer for the current device setup.
accelerator = Accelerator()


class DynaMoSSL(AbstractSSL):
    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
        window_size: int,
        feature_dim: int,
        projection_dim: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float = 0.0,
        covariance_reg_coef: float = 0.04,
        dynamics_loss_coef: float = 1.0,
        ema_beta: Optional[float] = None,
        beta_scheduling: bool = False,
        projector_use_ema: bool = False,
        lr: float = 1e-4,
        weight_decay: float = 0.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        separate_single_views: bool = True,
    ):
        nn.Module.__init__(self)

        # Keep plain references (bypassing nn.Module attribute registration) so
        # the encoder/projector are not registered as submodules of this wrapper.
        self.__dict__["encoder"] = encoder
        self.__dict__["projector"] = projector
        # Forward-dynamics model: predicts the next-step encoder feature from the
        # current feature concatenated with the next-step projection.
        forward_dynamics_cfg = TransformerEncoderConfig(
            block_size=window_size,
            input_dim=feature_dim + projection_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            output_dim=feature_dim,
        )
        self.forward_dynamics = TransformerEncoder(forward_dynamics_cfg)
        self.forward_dynamics_optimizer = self.forward_dynamics.configure_optimizers(
            weight_decay=weight_decay,
            lr=lr,
            betas=betas,
        )
        self.forward_dynamics, self.forward_dynamics_optimizer = accelerator.prepare(
            self.forward_dynamics,
            self.forward_dynamics_optimizer,
        )
        self.covariance_reg_coef = covariance_reg_coef
        self.dynamics_loss_coef = dynamics_loss_coef
        self.ema_beta = ema_beta
        self.beta_scheduling = beta_scheduling
        self.projector_use_ema = projector_use_ema
        if self.ema_beta is not None:
            # Optional EMA copies that produce prediction targets (and,
            # optionally, projections) from slowly updated weights.
            self.ema_encoder = EMA(self.encoder, self.ema_beta)
            if self.projector_use_ema:
                self.ema_projector = EMA(self.projector, self.ema_beta)
        self.separate_single_views = separate_single_views

    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        # obs_enc / obs_proj / obs_target are expected to have layout
        # (batch, time, views, dim); see the view indexing in _forward_dyn_loss.
        obs_enc = self.encoder(obs)
        if self.ema_beta is not None:
            # Targets come from the EMA encoder; projections may also come from
            # an EMA projector.
            obs_target = self.ema_encoder(obs)
            if self.projector_use_ema:
                obs_proj = self.ema_projector(obs_enc)
            else:
                obs_proj = self.projector(obs_enc)
        else:
            obs_target = obs_enc
            obs_proj = self.projector(obs_enc)

        covariance_loss = self._covariance_reg_loss(obs_enc)
        dynamics_loss, dynamics_loss_components = self._forward_dyn_loss(
            obs_enc, obs_proj, obs_target, self.separate_single_views
        )
        total_loss = dynamics_loss + covariance_loss
        loss_components = {
            "total_loss": total_loss,
            **dynamics_loss_components,
            "covariance_loss": covariance_loss,
        }
        return obs_enc, obs_proj, total_loss, loss_components

    def _forward_dyn_loss(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        separate_single_views: bool = True,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        V = obs_proj.shape[2]  # number of views
        total = torch.zeros(1, device=obs_enc.device)
        loss_components = {}
        if separate_single_views:
            # Each view predicts its own next-step features.
            for i in range(V):
                loss = self._forward_dyn_loss_one_pair(
                    obs_enc, obs_proj, obs_target, i, i
                )
                loss *= self.dynamics_loss_coef / V
                total += loss
                loss_components[f"dynamics_loss_{i}_{i}"] = loss
        else:
            # Every ordered pair of distinct views (i, j) contributes a term.
            total_view_pairs = V * (V - 1)
            for i in range(V):
                for j in range(V):
                    if i == j:
                        continue
                    loss = self._forward_dyn_loss_one_pair(
                        obs_enc, obs_proj, obs_target, i, j
                    )
                    loss *= self.dynamics_loss_coef / total_view_pairs
                    total += loss
                    loss_components[f"dynamics_loss_{i}_{j}"] = loss
        loss_components["dynamics_loss_total"] = total
        if self.ema_beta is not None:
            loss_components["ema_beta"] = torch.tensor(
                [self.ema_encoder.beta], device=obs_enc.device
            )
        return total, loss_components

    def _forward_dyn_loss_one_pair(
        self,
        obs_enc: torch.Tensor,
        obs_proj: torch.Tensor,
        obs_target: torch.Tensor,
        i: int,
        j: int,
    ) -> torch.Tensor:
        # Predict the next-step feature of view j from the current feature of
        # view j concatenated with the next-step projection of view i:
        # (B, T-1, feature_dim + projection_dim) -> (B, T-1, feature_dim).
        forward_dyn_input = torch.cat([obs_enc[:, :-1, j], obs_proj[:, 1:, i]], dim=-1)
        obs_enc_pred = self.forward_dynamics(forward_dyn_input)
        # One minus cosine similarity against the detached target features.
        loss = (
            1
            - torch.nn.functional.cosine_similarity(
                obs_enc_pred, obs_target[:, 1:, j].detach(), dim=-1
            ).mean()
        )
        return loss

    def _covariance_reg_loss(self, obs_enc: torch.Tensor) -> torch.Tensor:
        # Penalise off-diagonal covariance between embedding dimensions.
        loss = off_diag_cov_loss(obs_enc)
        return loss * self.covariance_reg_coef

    def adjust_beta(self, epoch: int, max_epoch: int):
        # Cosine schedule: beta starts at ema_beta at epoch 0 and anneals to 1.0
        # at max_epoch.
        if (self.ema_beta is None) or not self.beta_scheduling or (max_epoch == 0):
            return
        beta = 1.0 - 0.5 * (1.0 + np.cos(np.pi * epoch / max_epoch)) * (
            1.0 - self.ema_beta
        )
        self.ema_encoder.beta = beta
        if self.projector_use_ema:
            self.ema_projector.beta = beta

    def step(self):
        self.forward_dynamics_optimizer.step()
        self.forward_dynamics_optimizer.zero_grad(set_to_none=True)
        if self.ema_beta is not None:
            # Move the EMA copies towards the current online weights.
            self.ema_encoder.step(self.encoder)
            if self.projector_use_ema:
                self.ema_projector.step(self.projector)
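

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's API. The toy sizes,
    # the (batch, time, views, obs_dim) observation layout, and the plain
    # nn.Linear stand-ins for encoder/projector are illustrative assumptions;
    # it also assumes TransformerEncoder maps (B, T, input_dim) -> (B, T, output_dim)
    # and that AbstractSSL is an nn.Module subclass.
    B, T, V, obs_dim = 2, 8, 2, 32
    feature_dim, projection_dim = 16, 8

    encoder = nn.Linear(obs_dim, feature_dim)           # hypothetical stand-in
    projector = nn.Linear(feature_dim, projection_dim)  # hypothetical stand-in

    ssl = DynaMoSSL(
        encoder=encoder,
        projector=projector,
        window_size=T,
        feature_dim=feature_dim,
        projection_dim=projection_dim,
        n_layer=2,
        n_head=2,
        n_embd=32,
    )

    obs = torch.randn(B, T, V, obs_dim)
    obs_enc, obs_proj, total_loss, parts = ssl(obs)
    accelerator.backward(total_loss)
    ssl.step()
    print({name: float(value) for name, value in parts.items()})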