"""Experiment initialization: merged config loading and experiment directory setup."""
import os | |
import pprint | |
import shutil | |
import sys | |
from datetime import datetime | |
from pathlib import Path | |
import torch | |
import yaml | |
from easydict import EasyDict as edict | |
from .distributed import get_world_size, synchronize | |
from .log import add_logging, logger | |
def init_experiment(args, model_name):
    """Prepare an experiment run: resolve paths, load config, set up GPUs/logging.

    Steps (side effects throughout):
      1. Locate the model script under the repository "models" directory.
      2. Merge YAML configs and CLI args into one EasyDict ``cfg``.
      3. Create (or resume) a numbered experiment directory tree.
      4. Snapshot the model script into the experiment directory.
      5. Configure GPU devices (single, multi, or distributed) and logging.

    Args:
        args: parsed CLI namespace; reads model_path, distributed, local_rank,
            workers and temp_model_path (other attributes are merged into cfg).
        model_name: name used as the leaf of the experiment family tree.

    Returns:
        EasyDict config extended with EXP_PATH, CHECKPOINTS_PATH, VIS_PATH,
        LOGS_PATH, device, gpu_ids and related runtime fields.

    Exits the process (status 1) when the model is not under "models".
    """
    model_path = Path(args.model_path)
    ftree = get_model_family_tree(model_path, model_name=model_name)
    if ftree is None:
        print(
            'Models can only be located in the "models" directory in the root of the repository'
        )
        sys.exit(1)

    cfg = load_config(model_path)
    update_config(cfg, args)

    cfg.distributed = args.distributed
    cfg.local_rank = args.local_rank
    if cfg.distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        if args.workers > 0:
            # DataLoader workers + NCCL: forkserver avoids fork-after-init issues.
            torch.multiprocessing.set_start_method("forkserver", force=True)

    # Mirror the model's location under EXPS_PATH, e.g. EXPS_PATH/<family>/<model>.
    experiments_path = Path(cfg.EXPS_PATH)
    exp_parent_path = experiments_path / "/".join(ftree)
    exp_parent_path.mkdir(parents=True, exist_ok=True)

    if cfg.resume_exp:
        exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp)
    else:
        # New experiment: next free 3-digit index, optionally suffixed with a name.
        last_exp_indx = find_last_exp_indx(exp_parent_path)
        exp_name = f"{last_exp_indx:03d}"
        if cfg.exp_name:
            exp_name += "_" + cfg.exp_name
        exp_path = exp_parent_path / exp_name
        synchronize()  # all ranks agree on the index before rank 0 creates the dir
        if cfg.local_rank == 0:
            exp_path.mkdir(parents=True)

    cfg.EXP_PATH = exp_path
    cfg.CHECKPOINTS_PATH = exp_path / "checkpoints"
    cfg.VIS_PATH = exp_path / "vis"
    cfg.LOGS_PATH = exp_path / "logs"

    # Only rank 0 touches the filesystem; other ranks wait at synchronize() below.
    if cfg.local_rank == 0:
        cfg.LOGS_PATH.mkdir(exist_ok=True)
        cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True)
        cfg.VIS_PATH.mkdir(exist_ok=True)

        # Snapshot the model script (timestamped) for reproducibility.
        dst_script_path = exp_path / (
            model_path.stem
            + datetime.strftime(datetime.today(), "_%Y-%m-%d-%H-%M-%S.py")
        )
        if args.temp_model_path:
            # A pre-processed copy was provided; move it into the experiment dir.
            shutil.copy(args.temp_model_path, dst_script_path)
            os.remove(args.temp_model_path)
        else:
            shutil.copy(model_path, dst_script_path)

    synchronize()

    # Resolve the GPU set: explicit "--gpus 0,1" wins, else first ngpus/world_size.
    if cfg.gpus != "":
        gpu_ids = [int(id) for id in cfg.gpus.split(",")]
    else:
        gpu_ids = list(range(max(cfg.ngpus, get_world_size())))
        cfg.gpus = ",".join([str(id) for id in gpu_ids])
    cfg.gpu_ids = gpu_ids
    cfg.ngpus = len(gpu_ids)
    cfg.multi_gpu = cfg.ngpus > 1

    if cfg.distributed:
        # Each distributed process owns exactly one GPU, chosen by local_rank.
        cfg.device = torch.device("cuda")
        cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]]
        torch.cuda.set_device(cfg.gpu_ids[0])
    else:
        if cfg.multi_gpu:
            # Restrict visibility, then sanity-check the visible device count.
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus
            ngpus = torch.cuda.device_count()
            assert ngpus == cfg.ngpus
        cfg.device = torch.device(f"cuda:{cfg.gpu_ids[0]}")

    if cfg.local_rank == 0:
        add_logging(cfg.LOGS_PATH, prefix="train_")
        logger.info(f"Number of GPUs: {cfg.ngpus}")
        if cfg.distributed:
            logger.info(f"Multi-Process Multi-GPU Distributed Training")

        logger.info("Run experiment with config:")
        logger.info(pprint.pformat(cfg, indent=4))

    return cfg
def get_model_family_tree(model_path, terminate_name="models", model_name=None):
    """Collect directory names from *model_path* up to (excluding) *terminate_name*.

    Walks the parents of ``model_path`` until a directory named
    ``terminate_name`` is found, gathering each directory name on the way.

    Args:
        model_path: Path to the model script.
        terminate_name: Ancestor directory name that ends the walk.
        model_name: Leaf name to use; defaults to the script's stem.

    Returns:
        List ordered root-first, ending with the model name, or ``None``
        when no ancestor directory is called ``terminate_name``.
    """
    tree = [model_path.stem if model_name is None else model_name]
    for ancestor in model_path.parents:
        if ancestor.stem == terminate_name:
            tree.reverse()
            return tree
        tree.append(ancestor.stem)
    # Walked off the top without meeting terminate_name.
    return None
def find_last_exp_indx(exp_parent_path):
    """Return the next free experiment index inside *exp_parent_path*.

    Scans child directories whose names start with a 3-digit numeric
    prefix and returns one past the largest such prefix (0 when none).
    """
    candidates = (
        child.stem[:3]
        for child in exp_parent_path.iterdir()
        if child.is_dir() and child.stem[:3].isnumeric()
    )
    return max((int(prefix) + 1 for prefix in candidates), default=0)
def find_resume_exp(exp_parent_path, exp_pattern):
    """Locate exactly one experiment directory matching ``exp_pattern*``.

    Exits the process (status 1) when zero or more than one entry matches;
    otherwise prints and returns the single matching path.
    """
    matches = sorted(exp_parent_path.glob(f"{exp_pattern}*"))
    if not matches:
        print(
            f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"'
        )
        sys.exit(1)
    if len(matches) > 1:
        # Ambiguous pattern: list every candidate so the user can narrow it.
        print("More than one experiment found:")
        for match in matches:
            print(match)
        sys.exit(1)
    exp_path = matches[0]
    print(f'Continue with experiment "{exp_path}"')
    return exp_path
def update_config(cfg, args):
    """Copy CLI arguments into *cfg* without overriding existing config keys.

    An argument is skipped when its name already exists in ``cfg`` in either
    lower- or upper-case form; everything else is added in place.
    """
    for name, value in vars(args).items():
        if name.lower() not in cfg and name.upper() not in cfg:
            cfg[name] = value
def load_config(model_path):
    """Build the merged configuration for the model script at *model_path*.

    Starts from ``<model_name>.yml`` next to the model file (when present),
    then walks up the directory tree merging each ancestor's ``config.yml``.
    Keys already set by a closer (more specific) config are never overwritten.
    The walk stops after processing the current working directory, or at the
    filesystem root if cwd is never reached.

    Returns:
        EasyDict with the merged configuration (may be empty).
    """
    model_name = model_path.stem
    config_path = model_path.parent / (model_name + ".yml")

    if config_path.exists():
        cfg = load_config_file(config_path)
    else:
        cfg = dict()

    cwd = Path.cwd()
    config_parent = config_path.parent.absolute()
    while len(config_parent.parents) > 0:
        config_path = config_parent / "config.yml"

        if config_path.exists():
            local_config = load_config_file(config_path, model_name=model_name)
            # Deeper configs take precedence: only fill keys not set yet.
            cfg.update({k: v for k, v in local_config.items() if k not in cfg})

        if config_parent.absolute() == cwd:
            break
        config_parent = config_parent.parent

    return edict(cfg)
def load_config_file(config_path, model_name=None, return_edict=False):
    """Parse one YAML config file, folding in a model-specific sub-config.

    When the file contains a ``SUBCONFIGS`` mapping and *model_name* names one
    of its entries, that entry's keys override the top-level ones. The
    ``SUBCONFIGS`` key itself is always dropped from the result.

    Args:
        config_path: Path to the ``.yml`` file.
        model_name: Optional key to select inside ``SUBCONFIGS``.
        return_edict: When True, wrap the result in an EasyDict.

    Returns:
        dict (or EasyDict) with the loaded configuration.
    """
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    if "SUBCONFIGS" in cfg:
        subconfigs = cfg["SUBCONFIGS"]
        if model_name is not None and model_name in subconfigs:
            cfg.update(subconfigs[model_name])
        del cfg["SUBCONFIGS"]

    return edict(cfg) if return_edict else cfg