Spaces:
Runtime error
Runtime error
File size: 3,328 Bytes
a6dac9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other implementations, it is still quite a way off the built-in BatchNorm
in terms of memory usage and throughput (roughly 5x the memory and 1/2 to 1/3 of the speed).
Still very much a WIP: fiddling with buffer usage, in-place/JIT optimizations, and layouts.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
class EvoNormBatch2d(nn.Module):
    """Batched EvoNorm (B0) layer: variance-based normalization fused with an activation.

    With ``apply_act=True`` the forward pass computes::

        d = max(x * v + sqrt(instance_var + eps), sqrt(batch_var + eps))
        y = x / d * weight + bias

    where ``instance_var`` is the biased variance over dims (2, 3) of each
    sample/channel and ``batch_var`` is the biased variance over dims
    (0, 2, 3) while training, or the ``running_var`` buffer in eval mode.
    With ``apply_act=False`` only the affine ``x * weight + bias`` is applied.

    Args:
        num_features: Size of dim 1 of the input; ``weight``, ``bias``, ``v``
            and ``running_var`` all have shape ``(1, num_features, 1, 1)``.
        apply_act: Whether to apply the normalizing non-linearity. The ``v``
            parameter is only created when this is true.
        momentum: Weight given to the current batch variance when updating
            ``running_var``.
        eps: Constant added to variances before taking the square root.
        drop_block: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
        super(EvoNormBatch2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if apply_act:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` (and ``v``, if present) to ones and ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        """Apply the layer to a 4D tensor and return a tensor of the same shape.

        In training mode this also updates ``running_var`` in place with an
        exponential moving average of the Bessel-corrected batch variance.

        Raises:
            AssertionError: If ``x`` is not 4-dimensional.
        """
        assert x.dim() == 4, 'expected 4D input'
        x_type = x.dtype
        if self.training:
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            # Number of values that contribute to each channel's variance.
            n = x.numel() / x.shape[1]
            # Bessel correction (biased -> unbiased estimate) for the running
            # statistic. The denominator is clamped so that a single value per
            # channel (n == 1) no longer raises ZeroDivisionError; the biased
            # variance is 0 in that case, so the factor is irrelevant.
            correction = n / max(n - 1, 1)
            self.running_var.copy_(
                var.detach() * self.momentum * correction + self.running_var * (1 - self.momentum))
        else:
            var = self.running_var
        if self.apply_act:
            v = self.v.to(dtype=x_type)
            # Per-sample, per-channel std over the spatial dims plus the learned linear term.
            d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
            # The batch (or running) std acts as an element-wise lower bound on the denominator.
            d = d.max((var + self.eps).sqrt().to(dtype=x_type))
            x = x / d
        return x * self.weight + self.bias
class EvoNormSample2d(nn.Module):
    """Sample-based EvoNorm (S0) layer: a gated activation divided by a per-group std.

    With ``apply_act=True`` the forward pass computes
    ``x * sigmoid(v * x) / sqrt(group_var + eps) * weight + bias``, where
    ``group_var`` is the biased variance taken per sample over each of the
    ``groups`` equal slices of the flattened (C, H, W) dims. With
    ``apply_act=False`` only the affine ``x * weight + bias`` is applied.

    Args:
        num_features: Size of dim 1 of the input; ``weight``, ``bias`` and
            ``v`` have shape ``(1, num_features, 1, 1)``.
        apply_act: Whether to apply the normalizing non-linearity. The ``v``
            parameter is only created when this is true.
        groups: Number of groups; must evenly divide the channel count of the
            input (checked in ``forward``).
        eps: Constant added to the group variance before the square root.
        drop_block: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
        super().__init__()
        self.apply_act = apply_act  # whether the non-linearity is applied
        self.groups = groups
        self.eps = eps
        shape = (1, num_features, 1, 1)
        # Registration order (weight, bias, v) is kept so parameter ordering is unchanged.
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        if apply_act:
            self.v = nn.Parameter(torch.ones(shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``weight`` (and ``v``, if present) to ones and ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        """Apply the layer to a 4D tensor and return a tensor of the same shape.

        Raises:
            AssertionError: If ``x`` is not 4-dimensional or its channel count
                is not divisible by ``groups``.
        """
        assert x.dim() == 4, 'expected 4D input'
        batch, channels, height, width = x.shape
        assert channels % self.groups == 0
        if self.apply_act:
            grouped = (batch, self.groups, -1)
            # Numerator: x gated by a sigmoid of its learned per-channel scaling.
            gated = (x * (x * self.v).sigmoid()).reshape(grouped)
            # Denominator: biased std of the raw input within each group.
            group_var = x.reshape(grouped).var(dim=-1, unbiased=False, keepdim=True)
            group_std = (group_var + self.eps).sqrt()
            x = (gated / group_std).reshape(batch, channels, height, width)
        return x * self.weight + self.bias
|