# dynamo_ssl/models/ssl/dynamo.py

import torch
import einops
import numpy as np
import torch.nn as nn
from .base import AbstractSSL
from accelerate import Accelerator
from typing import Tuple, Dict, Optional
from ..transformer_encoder import TransformerEncoder, TransformerEncoderConfig
from ..ema import EMA


# https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py#L239
def off_diag(x):
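    """Flattened off-diagonal elements of a square matrix.

    >>> off_diag(torch.arange(9).view(3, 3))
    tensor([1, 2, 3, 5, 6, 7])
    """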
n, m = x.shape
assert n == m
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def off_diag_cov_loss(x: torch.Tensor) -> torch.Tensor:
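    """Mean squared off-diagonal covariance over the last (embedding) dim,
    treating all leading dims as samples.
    """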
cov = torch.cov(einops.rearrange(x, "... E -> E (...)"))
return off_diag(cov).square().mean()


accelerator = Accelerator()


class DynaMoSSL(AbstractSSL):
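    """DynaMo-style self-supervised objective.

    A causal transformer forward-dynamics model predicts the next-step
    observation embedding from the current embedding and the next-step
    projection; an off-diagonal covariance penalty regularizes the
    embeddings. The prediction target is either the online encoder itself
    (SimSiam-style, ema_beta=None) or an EMA copy of it.

    Minimal usage sketch (illustrative hyperparameters; encoder/projector
    are supplied and optimized by the caller):

        ssl = DynaMoSSL(encoder, projector, window_size=8, feature_dim=512,
                        projection_dim=32, n_layer=4, n_head=4, n_embd=256)
        obs_enc, obs_proj, loss, components = ssl(obs)
        accelerator.backward(loss)
        ssl.step()
    """
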
def __init__(
self,
encoder: nn.Module,
projector: nn.Module,
window_size: int,
feature_dim: int,
projection_dim: int,
n_layer: int,
n_head: int,
n_embd: int,
dropout: float = 0.0,
covariance_reg_coef: float = 0.04,
dynamics_loss_coef: float = 1.0,
ema_beta: Optional[float] = None, # None for SimSiam; float for EMA encoder
beta_scheduling: bool = False,
projector_use_ema: bool = False,
lr: float = 1e-4,
weight_decay: float = 0.0,
betas: Tuple[float, float] = (0.9, 0.999),
separate_single_views: bool = True,
):
nn.Module.__init__(self)
# avoid registering encoder/projector as submodules
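        # (__dict__ assignment bypasses nn.Module.__setattr__, keeping them
        # out of this module's parameters() and state_dict())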
self.__dict__["encoder"] = encoder
self.__dict__["projector"] = projector
forward_dynamics_cfg = TransformerEncoderConfig(
block_size=window_size,
input_dim=feature_dim + projection_dim,
n_layer=n_layer,
n_head=n_head,
n_embd=n_embd,
dropout=dropout,
output_dim=feature_dim,
)
self.forward_dynamics = TransformerEncoder(forward_dynamics_cfg)
self.forward_dynamics_optimizer = self.forward_dynamics.configure_optimizers(
weight_decay=weight_decay,
lr=lr,
betas=betas,
)
self.forward_dynamics, self.forward_dynamics_optimizer = accelerator.prepare(
self.forward_dynamics,
self.forward_dynamics_optimizer,
)
self.covariance_reg_coef = covariance_reg_coef
self.dynamics_loss_coef = dynamics_loss_coef
self.ema_beta = ema_beta
self.beta_scheduling = beta_scheduling
self.projector_use_ema = projector_use_ema
if self.ema_beta is not None:
self.ema_encoder = EMA(self.encoder, self.ema_beta)
if self.projector_use_ema:
self.ema_projector = EMA(self.projector, self.ema_beta)
self.separate_single_views = separate_single_views

    def forward(
self,
obs: torch.Tensor,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
Dict[str, torch.Tensor],
]:
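        """Encode an observation window and compute the SSL losses.

        The encoder output is expected to be (N, T, V, E): batch, time,
        views, embedding. Returns (obs_enc, obs_proj, total_loss,
        loss_components).
        """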
obs_enc = self.encoder(obs)
if self.ema_beta is not None:
obs_target = self.ema_encoder(obs) # use EMA encoder as target
if self.projector_use_ema:
obs_proj = self.ema_projector(obs_enc)
else:
obs_proj = self.projector(obs_enc)
else:
obs_target = obs_enc # use SimSiam target
obs_proj = self.projector(obs_enc)
covariance_loss = self._covariance_reg_loss(obs_enc)
dynamics_loss, dynamics_loss_components = self._forward_dyn_loss(
obs_enc, obs_proj, obs_target, self.separate_single_views
)
total_loss = dynamics_loss + covariance_loss
loss_components = {
"total_loss": total_loss,
**dynamics_loss_components,
"covariance_loss": covariance_loss,
}
return obs_enc, obs_proj, total_loss, loss_components

    def _forward_dyn_loss(
self,
obs_enc: torch.Tensor,
obs_proj: torch.Tensor,
obs_target: torch.Tensor,
separate_single_views: bool = True,
):
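        """Average the per-pair dynamics loss over views: each view predicts
        its own future (pairs (i, i)) when separate_single_views is set,
        otherwise all ordered cross-view pairs (i, j), i != j, contribute.
        """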
V = obs_proj.shape[2] # number of views
total = torch.zeros(1, device=obs_enc.device)
loss_components = {}
if separate_single_views:
for i in range(V):
loss = self._forward_dyn_loss_one_pair(
obs_enc, obs_proj, obs_target, i, i
)
loss *= self.dynamics_loss_coef / V
total += loss
loss_components[f"dynamics_loss_{i}_{i}"] = loss
else:
            total_view_pairs = V * (V - 1)  # ordered pairs (i, j), i != j
for i in range(V):
for j in range(V):
if i == j:
continue
loss = self._forward_dyn_loss_one_pair(
obs_enc, obs_proj, obs_target, i, j
)
loss *= self.dynamics_loss_coef / total_view_pairs
total += loss
loss_components[f"dynamics_loss_{i}_{j}"] = loss
loss_components["dynamics_loss_total"] = total
if self.ema_beta is not None:
            loss_components["ema_beta"] = torch.tensor(
                [self.ema_encoder.beta], device=obs_enc.device
            )
return total, loss_components

    def _forward_dyn_loss_one_pair(
self,
obs_enc: torch.Tensor,
obs_proj: torch.Tensor,
obs_target: torch.Tensor,
i: int,
j: int,
):
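        """Dynamics loss for one (source view i, target view j) pair: the
        transformer sees view j's embedding at time t concatenated with view
        i's projection at time t+1 and predicts view j's embedding at t+1;
        the loss is 1 - cosine similarity against the detached target.
        """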
forward_dyn_input = torch.cat([obs_enc[:, :-1, j], obs_proj[:, 1:, i]], dim=-1)
obs_enc_pred = self.forward_dynamics(forward_dyn_input) # (N, T-1, E)
loss = (
1
- torch.nn.functional.cosine_similarity(
obs_enc_pred, obs_target[:, 1:, j].detach(), dim=-1
).mean()
)
return loss

    def _covariance_reg_loss(self, obs_enc: torch.Tensor):
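        """VICReg-style penalty on the off-diagonal covariance of the
        embeddings, scaled by covariance_reg_coef.
        """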
loss = off_diag_cov_loss(obs_enc)
return loss * self.covariance_reg_coef

    def adjust_beta(self, epoch: int, max_epoch: int):
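        """Cosine-anneal the EMA decay from ema_beta (epoch 0) to 1.0 (max_epoch)."""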
if (self.ema_beta is None) or not self.beta_scheduling or (max_epoch == 0):
return
        beta = 1.0 - 0.5 * (1.0 + np.cos(np.pi * epoch / max_epoch)) * (
            1.0 - self.ema_beta
        )
        self.ema_encoder.beta = beta
        if self.projector_use_ema:
            self.ema_projector.beta = beta

    def step(self):
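        """Step the forward-dynamics optimizer (then zero its grads) and
        update the EMA targets.
        """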
self.forward_dynamics_optimizer.step()
self.forward_dynamics_optimizer.zero_grad(set_to_none=True)
if self.ema_beta is not None:
self.ema_encoder.step(self.encoder)
if self.projector_use_ema:
self.ema_projector.step(self.projector)