kernel-luso-comfort's picture
Add Apache License 2.0 header to multiple source files
202eff6
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
import pickle
import torch
import torch.nn as nn
from utilities.distributed import is_main_process
logger = logging.getLogger(__name__)

# Registry of normalization layer classes. It starts with the built-in torch
# layers below; ``register_norm_module`` appends further classes at runtime.
# NOTE(review): the consumers of this list are not visible in this chunk.
NORM_MODULES = [
    # Batch-norm family. NaiveSyncBatchNorm is not listed separately because
    # it inherits from BatchNorm2d.
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.SyncBatchNorm,
    # Group / instance / layer / local-response normalization.
    nn.GroupNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.LayerNorm,
    nn.LocalResponseNorm,
]
def register_norm_module(cls):
    """Class decorator that adds ``cls`` to the module-level ``NORM_MODULES`` list.

    The class is appended as-is (no de-duplication) and returned unchanged.
    This allows usage as ``@register_norm_module`` on a class definition.

    Args:
        cls: The class to register.

    Returns:
        The same ``cls`` object.
    """
    registry = NORM_MODULES  # mutated in place; the global binding is untouched
    registry.append(cls)
    return cls
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    """Select the checkpoint entries that fit the model and log a load report.

    Keys are matched by exact equality; nothing is renamed or reshaped. For
    every key of ``model_state_dict`` (visited in sorted order), the checkpoint
    entry with the same key is kept only when its ``.shape`` equals the model
    entry's ``.shape``.

    Args:
        model_state_dict: Mapping of name -> value for the target model. Values
            are compared and logged through their ``.shape`` attribute (e.g.
            tensors).
        ckpt_state_dict: Mapping of name -> value from a checkpoint. Values are
            accessed through ``.shape`` in the same way.

    Returns:
        dict: ``{key: ckpt_state_dict[key]}`` for each key that is present in
        both mappings with equal shapes, inserted in sorted key order. Keys
        missing from the checkpoint and keys with mismatched shapes are
        omitted. NOTE(review): presumably meant for a non-strict
        ``load_state_dict``; confirm at the call site.

    Side effects:
        When ``is_main_process()`` is true, the function logs the following,
        in this order:

        * INFO ``Loaded ...``: each key taken from the checkpoint.
        * WARNING ``*UNLOADED* ...``: model keys absent from the checkpoint.
        * WARNING ``$UNUSED$ ...``: checkpoint keys that were not taken,
          including shape-mismatched ones. These are listed in sorted order.
        * WARNING ``*UNMATCHED* ...``: keys present in both mappings with
          different shapes.
    """
    ckpt_keys = sorted(ckpt_state_dict.keys())
    # Track consumed checkpoint keys in a set. The former list membership test
    # plus list.pop(list.index(key)) was O(n) per key, i.e. O(n^2) overall.
    consumed = set()

    result_dicts = {}
    matched_log = []
    unmatched_log = []
    unloaded_log = []
    for model_key in sorted(model_state_dict.keys()):
        model_weight = model_state_dict[model_key]
        if model_key not in ckpt_state_dict:
            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
            continue

        ckpt_weight = ckpt_state_dict[model_key]
        if model_weight.shape == ckpt_weight.shape:
            result_dicts[model_key] = ckpt_weight
            consumed.add(model_key)
            matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
        else:
            # A mismatched checkpoint key is not consumed, so it is also
            # reported as $UNUSED$ below.
            unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))

    if is_main_process():
        for info in matched_log:
            logger.info(info)
        for info in unloaded_log:
            logger.warning(info)
        for key in ckpt_keys:
            if key not in consumed:
                logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
        for info in unmatched_log:
            logger.warning(info)
    return result_dicts