import os
import torch
import wandb
import random
import einops
import numpy as np
import torch.nn as nn
from . import inference
import torch.utils.data
from pathlib import Path
from hydra.types import RunMode
from typing import Callable, Dict
from prettytable import PrettyTable
from collections import OrderedDict
from torch.utils.data import random_split
from hydra.core.hydra_config import HydraConfig


# Modified from https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
def count_parameters(model):
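    """Return the number of trainable parameters and a PrettyTable of per-parameter counts."""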
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    return total_params, table


def get_split_idx(l, seed, train_fraction=0.95):
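    """Deterministically split `l` shuffled indices into train and validation lists."""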
    rng = torch.Generator().manual_seed(seed)
    idx = torch.randperm(l, generator=rng).tolist()
    l_train = int(l * train_fraction)
    return idx[:l_train], idx[l_train:]


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
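    """Build an MLP with `hidden_depth` ReLU hidden layers of width `hidden_dim` (a single
    linear layer if `hidden_depth` is 0), optionally followed by `output_mod`."""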
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        mods.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        mods.append(output_mod)
    trunk = nn.Sequential(*mods)
    return trunk


def freeze_module(module: nn.Module) -> nn.Module:
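    """Disable gradients for all parameters of `module` and switch it to eval mode."""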
    for param in module.parameters():
        param.requires_grad = False
    module.eval()
    return module


def set_seed_everywhere(seed):
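    """Seed torch, CUDA, numpy, and `random`, and enable deterministic torch algorithms."""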
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.use_deterministic_algorithms(True)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def batch_indexing(input, idx):
"""
Given an input with shape (*batch_shape, k, *value_shape),
and an index with shape (*batch_shape) with values in [0, k),
index the input on the k dimension.
Returns: (*batch_shape, *value_shape)
"""
batch_shape = idx.shape
dim = len(idx.shape)
value_shape = input.shape[dim + 1 :]
N = batch_shape.numel()
assert input.shape[:dim] == batch_shape, "Input batch shape must match index shape"
assert len(value_shape) > 0, "No values left after indexing"
# flatten the batch shape
input_flat = input.reshape(N, *input.shape[dim:])
idx_flat = idx.reshape(N)
result = input_flat[np.arange(N), idx_flat]
return result.reshape(*batch_shape, *value_shape)
def shuffle_along_axis(a, axis):
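    """Shuffle the values of `a` independently along the given axis."""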
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a, idx, axis=axis)


def transpose_batch_timestep(*args):
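    """Rearrange each argument from (batch, time, ...) to (time, batch, ...); returns a generator."""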
    return (einops.rearrange(arg, "b t ... -> t b ...") for arg in args)


class TrainWithLogger:
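    """Mixin that accumulates per-batch loss components and logs epoch averages to wandb."""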
    def reset_log(self):
        self.log_components = OrderedDict()

    def log_append(self, log_key, length, loss_components):
        for key, value in loss_components.items():
            key_name = f"{log_key}/{key}"
            count, sum = self.log_components.get(key_name, (0, 0.0))
            self.log_components[key_name] = (
                count + length,
                sum + (length * value.detach().cpu().item()),
            )

    def flush_log(self, epoch, iterator=None):
        log_components = OrderedDict()
        iterator_log_component = OrderedDict()
        for key, value in self.log_components.items():
            count, sum = value
            to_log = sum / count
            log_components[key] = to_log
            # Set the iterator status
            log_key, name_key = key.split("/")
            iterator_log_name = f"{log_key[0]}{name_key[0]}".upper()
            iterator_log_component[iterator_log_name] = to_log
        postfix = ",".join(
            "{}:{:.2e}".format(key, iterator_log_component[key])
            for key in iterator_log_component.keys()
        )
        if iterator is not None:
            iterator.set_postfix_str(postfix)
        wandb.log(log_components, step=epoch)
        self.log_components = OrderedDict()


class SaveModule(nn.Module):
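    """Mixin that saves/loads the module state_dict as `snapshot.pth` under `snapshot_path`."""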
    def set_snapshot_path(self, path):
        self.snapshot_path = path
        print(f"Setting snapshot path to {self.snapshot_path}")

    def save_snapshot(self):
        os.makedirs(self.snapshot_path, exist_ok=True)
        torch.save(self.state_dict(), self.snapshot_path / "snapshot.pth")

    def load_snapshot(self):
        self.load_state_dict(torch.load(self.snapshot_path / "snapshot.pth"))


def split_datasets(dataset, train_fraction=0.95, random_seed=42):
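    """Split a dataset into train and validation subsets using a seeded `random_split`."""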
    dataset_length = len(dataset)
    lengths = [
        int(train_fraction * dataset_length),
        dataset_length - int(train_fraction * dataset_length),
    ]
    train_set, val_set = random_split(
        dataset, lengths, generator=torch.Generator().manual_seed(random_seed)
    )
    return train_set, val_set


def reduce_dict(f: Callable, d: Dict):
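    """Apply `f` to every non-dict leaf of `d`, recursing into nested dicts."""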
    return {k: reduce_dict(f, v) if isinstance(v, dict) else f(v) for k, v in d.items()}


def get_hydra_jobnum_workdir():
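    """Return the Hydra job number and working directory for the current (multi)run."""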
    if HydraConfig.get().mode == RunMode.MULTIRUN:
        job_num = HydraConfig.get().job.num
        work_dir = Path(HydraConfig.get().sweep.dir) / HydraConfig.get().sweep.subdir
    else:
        job_num = 0
        # Wrap in Path so both branches return the same type.
        work_dir = Path(HydraConfig.get().run.dir)
    return job_num, work_dir