import inspect

import torch.nn as nn
import torch.nn.functional as F


def _make_activation(activation):
    """Instantiate ``activation``, requesting in-place mode only when supported.

    Callables whose signature has an ``inplace`` parameter (e.g. ``nn.ReLU``)
    are built with ``inplace=True``. Those without one (e.g. ``nn.Tanh``) are
    built with no arguments instead of failing with ``TypeError``. If the
    signature cannot be introspected, fall back to ``activation(inplace=True)``.
    """
    try:
        params = inspect.signature(activation).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: keep the historical behavior.
        return activation(inplace=True)
    if "inplace" in params:
        return activation(inplace=True)
    return activation()


def _hidden_block(in_dim, out_dim, batchnorm, activation):
    """Return the modules of one hidden layer: Linear, optional BatchNorm1d, activation."""
    layers = [nn.Linear(in_dim, out_dim)]
    if batchnorm:
        layers.append(nn.BatchNorm1d(out_dim))
    layers.append(_make_activation(activation))
    return layers


def mlp(
    input_dim,
    hidden_dim,
    output_dim,
    hidden_depth,
    output_mod=None,
    batchnorm=False,
    activation=nn.ReLU,
):
    """Build a multi-layer perceptron as an ``nn.Sequential``.

    Args:
        input_dim: Number of input features of the first ``nn.Linear``.
        hidden_dim: Width of every hidden layer (unused when ``hidden_depth == 0``).
        output_dim: Number of output features of the last ``nn.Linear``.
        hidden_depth: Number of hidden layers. ``0`` yields a single
            ``nn.Linear(input_dim, output_dim)``.
        output_mod: Optional module appended after the final linear layer.
        batchnorm: If true, insert ``nn.BatchNorm1d`` after each hidden linear
            layer, before its activation.
        activation: Activation class or factory, called once per hidden layer.
            ``inplace=True`` is passed only when its signature accepts it.

    Returns:
        The assembled ``nn.Sequential``.

    Raises:
        ValueError: If ``hidden_depth`` is negative.
    """
    if hidden_depth < 0:
        raise ValueError(f"hidden_depth must be >= 0, got {hidden_depth}")

    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = _hidden_block(input_dim, hidden_dim, batchnorm, activation)
        # The first hidden layer was built above; the rest map hidden -> hidden.
        for _ in range(hidden_depth - 1):
            mods += _hidden_block(hidden_dim, hidden_dim, batchnorm, activation)
        mods.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        mods.append(output_mod)
    return nn.Sequential(*mods)


def weight_init(m):
    """Custom weight init for ``nn.Linear`` layers; other modules are left untouched.

    Weights get orthogonal initialization and biases (when present) are zeroed.
    Intended for use with ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight)
        # bias is None when the layer was created with bias=False.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class MLP(nn.Module):
    """Module wrapper around :func:`mlp` with orthogonal Linear initialization.

    Constructor arguments are forwarded unchanged to :func:`mlp`. After
    construction, every ``nn.Linear`` in the trunk is re-initialized by
    :func:`weight_init`.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        hidden_depth,
        output_mod=None,
        batchnorm=False,
        activation=nn.ReLU,
    ):
        super().__init__()
        self.trunk = mlp(
            input_dim,
            hidden_dim,
            output_dim,
            hidden_depth,
            output_mod,
            batchnorm=batchnorm,
            activation=activation,
        )
        self.apply(weight_init)

    def forward(self, x):
        """Apply the MLP trunk to ``x``."""
        return self.trunk(x)