jeffacce
initial commit
393d3de
import abc
import torch
import torch.nn as nn
from typing import Tuple, Dict
class AbstractSSL(nn.Module):
    """
    Abstract interface for a self-supervised learning (SSL) method.

    A concrete subclass should contain everything inside the SSL method
    (e.g. key queue for MoCo, EMA for BYOL, etc.), the loss function, and
    the optimizer.

    Every method defined here raises ``NotImplementedError``; subclasses are
    expected to override ``__init__``, ``forward`` and ``step``.

    NOTE: this class derives only from ``nn.Module`` (no ``abc.ABCMeta``
    metaclass), so the ``abc.abstractmethod`` markers declare the required
    interface for readers and static tooling but are not enforced when a
    subclass is instantiated, unless that subclass opts into ``abc.ABCMeta``.
    """

    @abc.abstractmethod
    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
    ) -> None:
        """
        Initializes the SSL method.

        This base implementation always raises, so a subclass cannot delegate
        to it via ``super().__init__(encoder, projector)``; it has to set up
        the ``nn.Module`` state itself (e.g. ``nn.Module.__init__(self)``).

        Inputs:
            encoder: the encoder module
            projector: the projector module

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Encodes and projects the observations and computes the SSL loss.

        Inputs:
            obs: the input observations

        Outputs:
            obs_enc: the encoded observations
            obs_proj: the projected observations
            loss: the total loss
            loss_components: the components of the total loss

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    # Marked abstract for consistency with __init__ and forward: step is part
    # of the same required interface and has no usable default behavior.
    @abc.abstractmethod
    def step(self) -> None:
        """
        This function should be called at each training step to update the
        SSL method's internal state,
        e.g. step the optimizer, update the key queue for MoCo, update EMA
        for BYOL, etc.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError