# dynamo_ssl/train.py
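"""Self-supervised pre-training driver.

Instantiates a dataset, encoder, projector, and SSL objective from the Hydra
config, wraps them with Hugging Face Accelerate for (multi-GPU) training, logs
to Weights & Biases, and alternates SSL training epochs with evaluation while
periodically saving snapshots to the Hydra work directory.
"""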
import os
import tqdm
import utils
import hydra
import torch
import einops
import datasets
import numpy as np
import torch.distributed
from pathlib import Path
from datetime import timedelta
from omegaconf import OmegaConf
from accelerate import Accelerator
from collections import OrderedDict
from workspaces.base import Workspace
from torch.utils.data import DataLoader
from accelerate.logging import get_logger
from accelerate import InitProcessGroupKwargs, DistributedDataParallelKwargs
os.environ["WANDB_START_METHOD"] = "thread"
logger = get_logger(__name__)
class Trainer:
    def __init__(self, cfg):
        process_group_kwargs = InitProcessGroupKwargs(
            timeout=timedelta(seconds=cfg.timeout_seconds)
        )
        dist_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.cfg = cfg
        self.effective_batch_size = self.cfg.batch_size
        self.accelerator = Accelerator(
            log_with="wandb", kwargs_handlers=[process_group_kwargs, dist_kwargs]
        )
        logger.info(f"Mixed precision: {self.accelerator.mixed_precision}")
        utils.set_seed_everywhere(cfg.seed)
        self.job_num, self.work_dir = utils.get_hydra_jobnum_workdir()
        # all processes use the work_dir from the main process
        if torch.distributed.is_initialized():
            objs = [str(self.work_dir)]
            torch.distributed.broadcast_object_list(objs, 0)
            self.work_dir = Path(objs[0])
        self.accelerator.wait_for_everyone()
        logger.info("Saving to {}".format(self.work_dir))
        os.chdir(self.work_dir)
        self.work_dir = Path(os.getcwd())  # get the absolute path
        self.dataset = hydra.utils.instantiate(cfg.env.dataset)
        self.train_set, self.test_set = self._split_and_slice_dataset(self.dataset)
        self._setup_loaders(batch_size=self.cfg.batch_size)
        self._init_tracker(cfg)
        # Create the model
        self.encoder = None
        self.projector = None
        self.ssl = None
        self._init_encoder()
        self._init_projector()
        self._init_ssl()
        self.workspace: Workspace = hydra.utils.instantiate(
            self.cfg.env.workspace,
            cfg=self.cfg,
            work_dir=self.work_dir,
            _recursive_=False,
        )
        self.workspace.set_dataset(self.dataset)
        self.log_components = OrderedDict()
        self.epoch = 0
    def _init_tracker(self, cfg):
        wandb_cfg = OmegaConf.to_container(cfg, resolve=True)
        wandb_cfg["effective_batch_size"] = self.effective_batch_size
        wandb_cfg["save_path"] = str(self.work_dir)
        self.accelerator.init_trackers(
            project_name=cfg.project,
            config=wandb_cfg,
            init_kwargs={
                "wandb": {
                    "reinit": False,
                    "settings": {"start_method": "thread"},
                },
            },
        )
        if self.accelerator.is_main_process:
            self.wandb_run = self.accelerator.get_tracker("wandb", unwrap=True)
            logger.info("wandb run url: %s", self.wandb_run.get_url())
    def _init_encoder(self):
        if self.encoder is None:  # possibly already initialized from snapshot
            self.encoder = hydra.utils.instantiate(self.cfg.encoder)
            if self.cfg.sync_bn:
                self.encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.encoder
                )
        self.encoder_optim = torch.optim.AdamW(
            params=self.encoder.parameters(),
            lr=self.cfg.ssl_lr,
            weight_decay=self.cfg.ssl_weight_decay,
            betas=tuple(self.cfg.betas),
        )
        (
            self.encoder,
            self.encoder_optim,
        ) = self.accelerator.prepare(self.encoder, self.encoder_optim)
        if self.accelerator.is_main_process:
            self.wandb_run.watch(self.encoder)
    def _init_projector(self):
        if self.projector is None:  # possibly already initialized from snapshot
            self.projector = hydra.utils.instantiate(
                self.cfg.projector, _recursive_=False
            )
        self.projector_optim: torch.optim.Optimizer = (
            self.projector.configure_optimizers(
                lr=self.cfg.ssl_lr,
                weight_decay=self.cfg.ssl_weight_decay,
                betas=tuple(self.cfg.betas),
            )
        )
        (
            self.projector,
            self.projector_optim,
        ) = self.accelerator.prepare(self.projector, self.projector_optim)
    def _init_ssl(self):
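        # The instantiated SSL objective is expected to expose forward(obs) ->
        # (obs_enc, obs_proj, loss, loss_components), plus adjust_beta(epoch,
        # num_epochs), step(), and parameters(); train() and eval() below rely
        # on this interface.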
        if self.ssl is None:
            self.ssl = hydra.utils.instantiate(
                self.cfg.ssl,
                encoder=self.encoder,
                projector=self.projector,
            )
    def _split_and_slice_dataset(self, dataset):
        kwargs = {
            "train_fraction": self.cfg.train_fraction,
            "random_seed": self.cfg.seed,
            "window_size": self.cfg.window_size,
            "future_conditional": (self.cfg.goal_conditional == "future"),
            "min_future_sep": self.cfg.min_future_sep,
            "future_seq_len": self.cfg.goal_seq_len,
            "num_extra_predicted_actions": self.cfg.num_extra_predicted_actions,
        }
        return datasets.core.get_train_val_sliced(dataset, **kwargs)
    def _setup_loaders(self, batch_size=None, pin_memory=True, num_workers=None):
        if num_workers is None:
            num_workers = self.cfg.num_workers
        kwargs = {
            "batch_size": batch_size or self.cfg.batch_size,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
        }
        # the configured batch size is global; split it evenly across processes
        assert kwargs["batch_size"] % self.accelerator.num_processes == 0, (
            "Batch size must be divisible by the number of processes. "
            f"Got {kwargs['batch_size']} and {self.accelerator.num_processes}."
        )
        kwargs["batch_size"] = kwargs["batch_size"] // self.accelerator.num_processes
        self.train_loader = DataLoader(self.train_set, shuffle=True, **kwargs)
        self.test_loader = DataLoader(self.test_set, shuffle=False, **kwargs)
        self.train_loader = self.accelerator.prepare(self.train_loader)
        self.test_loader = self.accelerator.prepare(self.test_loader)
    def train(self):
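        """Run one epoch of SSL training over the train loader."""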
        if self.cfg.use_lr_scheduling:
            lr = self.adjust_lr()
            self.log_append("metrics", 1, {"lr": lr})
        self.ssl.adjust_beta(self.epoch, self.cfg.num_epochs)
        pbar = tqdm.tqdm(
            self.train_loader,
            desc=f"Training epoch {self.epoch}",
            disable=not self.accelerator.is_main_process,
            ncols=80,
        )
        for data in pbar:
            obs, _, _ = data
            with self.accelerator.autocast():
                (
                    obs_enc,
                    obs_proj,
                    ssl_loss,
                    ssl_loss_components,
                ) = self.ssl.forward(obs)
            self.log_append("ssl_train", len(obs), ssl_loss_components)
            self.accelerator.backward(ssl_loss, retain_graph=True)
            if self.cfg.clip_grad_norm:
                self.accelerator.clip_grad_norm_(
                    self.encoder.parameters(), self.cfg.clip_grad_norm
                )
                self.accelerator.clip_grad_norm_(
                    self.projector.parameters(), self.cfg.clip_grad_norm
                )
                self.accelerator.clip_grad_norm_(
                    self.ssl.parameters(), self.cfg.clip_grad_norm
                )
            self.encoder_optim.step()
            self.projector_optim.step()
            self.ssl.step()
            self.encoder_optim.zero_grad(set_to_none=True)
            self.projector_optim.zero_grad(set_to_none=True)
    def eval(self):
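        """Run offline eval via the workspace and SSL eval on the test set."""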
        if self.cfg.eval_offline:
            # env-specific offline eval
            self.workspace.set_models(
                encoder=self.encoder,
                projector=self.projector,
            )
            offline_eval_results = self.workspace.run_offline_eval()
            if self.accelerator.is_main_process:
                self.log_append("env_offline_eval", 1, offline_eval_results)
        with utils.inference.eval_mode(
            self.encoder,
            self.projector,
            no_grad=True,
        ):
            # eval on test set
            self.eval_loss = 0
            for data in self.test_loader:
                obs, _, _ = data
                (
                    obs_enc,
                    obs_proj,
                    ssl_loss,
                    ssl_loss_components,
                ) = self.ssl.forward(obs)
                ssl_loss = self.accelerator.gather_for_metrics(ssl_loss).mean()
                # accumulate the gathered SSL loss, averaged over test batches;
                # run() returns this as the final eval metric
                self.eval_loss += ssl_loss.item() / len(self.test_loader)
                ssl_loss_components = utils.reduce_dict(
                    torch.mean,
                    self.accelerator.gather_for_metrics(ssl_loss_components),
                )
                self.log_append(
                    "ssl_eval",
                    len(obs),
                    ssl_loss_components,
                )
                flat_obs_enc = self.accelerator.gather_for_metrics(obs_enc)
                flat_obs_enc = einops.rearrange(flat_obs_enc, "N T V E -> (N T V) E")
                obs_enc_mean_std = flat_obs_enc.std(dim=0).mean()
                obs_enc_mean_norm = flat_obs_enc.norm(dim=-1).mean()
                self.log_append(
                    "metrics",
                    len(flat_obs_enc),
                    {
                        "obs_enc_mean_std": obs_enc_mean_std,
                        "obs_enc_mean_norm": obs_enc_mean_norm,
                    },
                )
                flat_obs_proj = self.accelerator.gather_for_metrics(obs_proj)
                flat_obs_proj = einops.rearrange(flat_obs_proj, "N T V Z -> (N T V) Z")
                obs_proj_mean_std = flat_obs_proj.std(dim=0).mean()
                obs_proj_mean_norm = flat_obs_proj.norm(dim=-1).mean()
                self.log_append(
                    "metrics",
                    len(flat_obs_proj),
                    {
                        "obs_proj_mean_std": obs_proj_mean_std,
                        "obs_proj_mean_norm": obs_proj_mean_norm,
                    },
                )
    def run(self):
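        """Resume from a snapshot if present, then alternate train and eval epochs."""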
        snapshot = Path(self.work_dir) / "snapshot.pt"
        if snapshot.exists():
            print(f"Resuming: {snapshot}")
            self.load_snapshot()
        self.train_iterator = tqdm.trange(
            self.epoch,
            self.cfg.num_epochs,
            disable=not self.accelerator.is_main_process,
            ncols=80,
        )
        self.train_iterator.set_description("Training")
        # Reset the log.
        self.log_components = OrderedDict()
        for epoch in self.train_iterator:
            self.epoch = epoch
            self.train()
            self.eval()
            self.flush_log(step=self.epoch, iterator=self.train_iterator)
            if (self.epoch + 1) % self.cfg.save_every_epochs == 0:
                self.save_snapshot()
        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()
        return float(self.eval_loss)
    def save_snapshot(self):
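        """Save models, optimizers, and the epoch to snapshot files (main process only)."""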
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self._keys_to_save = [
                "encoder",
                "projector",
                "encoder_optim",
                "projector_optim",
                "ssl",
                "epoch",
            ]
            payload = {}
            # if key is an accelerator DDP model, unwrap
            for k in self._keys_to_save:
                if hasattr(self.__dict__[k], "module"):
                    payload[k] = self.accelerator.unwrap_model(self.__dict__[k])
                else:
                    payload[k] = self.__dict__[k]
            with (self.work_dir / "snapshot.pt").open("wb") as f:
                torch.save(payload, f)
            with (self.work_dir / "encoder.pt").open("wb") as f:
                torch.save(payload["encoder"], f)
            with (self.work_dir / f"snapshot_{self.epoch}.pt").open("wb") as f:
                torch.save(payload, f)
            with (self.work_dir / f"encoder_{self.epoch}.pt").open("wb") as f:
                torch.save(payload["encoder"], f)
    def load_snapshot(self):
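        """Restore trainer state from the snapshot in work_dir."""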
        with (self.work_dir / "snapshot.pt").open("rb") as f:
            payload = torch.load(f)
        for k, v in payload.items():
            self.__dict__[k] = v
        # _keys_to_save is only populated by save_snapshot, so fall back to the
        # payload keys when resuming in a fresh process
        not_in_payload = set(getattr(self, "_keys_to_save", payload.keys())) - set(
            payload.keys()
        )
        if len(not_in_payload):
            logger.warning("Keys not found in snapshot: %s", not_in_payload)
    def log_append(self, log_key, length, loss_components):
        for key, value in loss_components.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().item()
            key_name = f"{log_key}/{key}"
            count, sum = self.log_components.get(key_name, (0, 0.0))
            self.log_components[key_name] = (
                count + length,
                sum + (length * value),
            )
    def flush_log(self, step, iterator=None):
        log_components = OrderedDict()
        iterator_log_component = OrderedDict()
        for key, value in self.log_components.items():
            count, sum = value
            to_log = sum / count
            log_components[key] = to_log
            # Set the iterator status
            log_key, name_key = key.split("/")
            iterator_log_name = f"{log_key[0]}{name_key[0]}".upper()
            iterator_log_component[iterator_log_name] = to_log
        postfix = ",".join(
            "{}:{:.2e}".format(key, iterator_log_component[key])
            for key in iterator_log_component.keys()
        )
        if iterator is not None:
            iterator.set_postfix_str(postfix)
        self.accelerator.log(log_components, step=step)
        logger.info(f"[{self.job_num}] Epoch {self.epoch}: {log_components}")
        self.log_components = OrderedDict()
    def adjust_lr(self):
        # from https://github.com/facebookresearch/moco-v3/blob/c349e6e24f40d3fedb22d973f92defa4cedf37a7/main_moco.py#L420
        """Decays the learning rate with half-cycle cosine after warmup"""
        # fmt: off
        if self.epoch < self.cfg.warmup_epochs:
            lr = self.cfg.ssl_lr * self.epoch / self.cfg.warmup_epochs
        else:
            lr = self.cfg.ssl_lr * 0.5 * (1.0 + np.cos(np.pi * (self.epoch - self.cfg.warmup_epochs) / (self.cfg.num_epochs - self.cfg.warmup_epochs)))
        # fmt: on
        optimizers = [self.encoder_optim, self.projector_optim]
        for optim in optimizers:
            for param_group in optim.param_groups:
                param_group["lr"] = lr
        return lr
@hydra.main(version_base="1.2", config_path="configs", config_name="train")
def main(cfg):
    trainer = Trainer(cfg)
    eval_loss = trainer.run()
    return eval_loss
if __name__ == "__main__":
    main()
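# Example invocation (a sketch only; the exact config groups and override names
# are defined by configs/train.yaml and its "env" entries, so the values below
# are assumptions):
#   accelerate launch train.py env=<env_name> num_epochs=400 batch_size=1024
# Any cfg field read above (e.g. ssl_lr, window_size, save_every_epochs) can be
# overridden the same way. Re-running in a work_dir that already contains
# snapshot.pt resumes training from the saved epoch.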