File size: 1,347 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import abc

import torch
import torch.nn as nn
from typing import Tuple, Dict


class AbstractSSL(nn.Module):
    """
    Base interface for a self-supervised learning (SSL) method.

    A concrete subclass should contain everything inside the SSL method
    (e.g. key queue for MoCo, EMA for BYOL, etc.), the loss function, and the optimizer.

    NOTE: ``nn.Module`` does not use ``abc.ABCMeta`` as its metaclass, so the
    ``@abc.abstractmethod`` markers below are not enforced by Python at
    instantiation time. The contract is enforced at runtime instead: the base
    class refuses to be instantiated directly, and ``forward`` / ``step`` raise
    ``NotImplementedError`` until a subclass overrides them.
    """

    @abc.abstractmethod
    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
    ):
        """
        Initializes the SSL method.

        Subclasses may call ``super().__init__(encoder, projector)``; doing so
        initializes the underlying ``nn.Module`` and registers both modules as
        ``self.encoder`` and ``self.projector``.

        Inputs:
            encoder: the encoder module
            projector: the projector module
        Raises:
            NotImplementedError: if ``AbstractSSL`` itself (not a subclass) is instantiated
        """
        # Exact-type check on purpose: only the abstract base must be rejected,
        # every subclass has to be allowed through.
        if type(self) is AbstractSSL:
            raise NotImplementedError
        # nn.Module state must exist before child modules can be assigned.
        super().__init__()
        self.encoder = encoder
        self.projector = projector

    @abc.abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Runs the SSL method on a batch of observations. Must be overridden by subclasses.

        Inputs:
            obs: the input observations
        Outputs:
            obs_enc: the encoded observations
            obs_proj: the projected observations
            loss: the total loss
            loss_components: the components of the total loss
        Raises:
            NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def step(self):
        """
        This function should be called at each training step to update the SSL method's internal state.
        e.g. step the optimizer, update the key queue for MoCo, update EMA for BYOL, etc.

        Must be overridden by subclasses; this base implementation always raises NotImplementedError.
        """
        raise NotImplementedError