import torch
import torch.nn as nn
from copy import deepcopy


class EMA(nn.Module):
    """Exponential moving average (EMA) of a source model's weights."""

    def __init__(self, src_model: nn.Module, beta: float, copy: bool = True):
        super().__init__()
        # Track a private deep copy by default, or wrap the source model in place.
        self.model = deepcopy(src_model) if copy else src_model
        # The EMA model is never trained directly: freeze it and keep it in eval mode.
        self.model.eval()
        self.model.requires_grad_(False)
        self.beta = beta

    @torch.no_grad()
    def step(self, src_model: nn.Module):
        """Update the EMA weights: ema = beta * ema + (1 - beta) * src."""
        one_minus_beta = 1.0 - self.beta
        for ema_param, src_param in zip(
            self.model.parameters(), src_model.parameters()
        ):
            # In-place interpolation toward the current source weights.
            ema_param.mul_(self.beta).add_(src_param, alpha=one_minus_beta)
        # Note: only parameters are averaged; buffers such as BatchNorm running
        # statistics are not synchronized by this update.

    def forward(self, *args, **kwargs):
        # Evaluate the averaged model; the EMA weights are inference-only.
        with torch.no_grad():
            return self.model(*args, **kwargs)
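

# A minimal usage sketch, not part of the original module: the two-layer `net`,
# the SGD optimizer, and the random batches below are hypothetical stand-ins
# showing where EMA.step() sits in a training loop.
if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    ema = EMA(net, beta=0.999)  # deep-copies `net` and freezes the copy
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)

    for _ in range(10):
        x = torch.randn(4, 8)
        loss = net(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.step(net)  # pull the EMA weights toward the freshly updated ones

    # Inference goes through the EMA wrapper; gradients are disabled inside.
    y = ema(torch.randn(4, 8))
    print(y.shape)  # torch.Size([4, 1])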