import abc import torch import torch.nn as nn from typing import Tuple, Dict class AbstractSSL(nn.Module): """ This class should contain everything inside the SSL method (e.g. key queue for MoCo, EMA for BYOL, etc.), loss function, and the optimizer. """ @abc.abstractmethod def __init__( self, encoder: nn.Module, projector: nn.Module, ): """ Initializes the SSL method. Inputs: encoder: the encoder module projector: the projector module """ raise NotImplementedError @abc.abstractmethod def forward( self, obs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor],]: """ Inputs: obs: the input observations Outputs: obs_enc: the encoded observations obs_proj: the projected observations loss: the total loss loss_components: the components of the total loss """ raise NotImplementedError def step(self): """ This function should be called at each training step to update the SSL method's internal state. e.g. step the optimizer, update the key queue for MoCo, update EMA for BYOL, etc. """ raise NotImplementedError