# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
from typing import Optional, Tuple, Union

import torch
from torch import nn
# Accepted forms for per-channel normalization statistics: a 3-tuple of floats
# or a tensor (anything ``torch.as_tensor`` can turn into a 1-D sequence).
norm_t = Union[Tuple[float, float, float], torch.Tensor]
class InputConditioner(nn.Module):
    """Normalize inputs with fixed per-channel mean/std statistics.

    ``forward`` computes ``(x - norm_mean) / norm_std``, where the stored
    buffers are the supplied statistics divided by ``input_scale``.
    Algebraically this equals ``(x * input_scale - mean) / std``, so
    ``input_scale`` is the factor that maps raw input values onto the scale
    in which ``mean``/``std`` are expressed.
    """

    def __init__(
        self,
        input_scale: float,
        norm_mean: norm_t,
        norm_std: norm_t,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Args:
            input_scale: Divisor applied to both ``norm_mean`` and ``norm_std``
                before they are stored (see the class docstring).
            norm_mean: Per-channel mean, converted to a float32 ``(C, 1, 1)``
                tensor.
            norm_std: Per-channel standard deviation, converted to a float32
                ``(C, 1, 1)`` tensor.
            dtype: If not ``None``, ``forward`` casts its output to this dtype;
                otherwise the output keeps the dtype produced by the arithmetic.
        """
        super().__init__()
        self.dtype = dtype
        # Buffers (not parameters): they move with ``.to(device)`` and are part
        # of the state dict, but are never trained.
        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - norm_mean) / norm_std``, optionally cast to ``self.dtype``.

        Args:
            x: Input tensor. The ``(C, 1, 1)`` buffers broadcast over its last
                three dims, so it is presumably ``(..., C, H, W)`` — TODO
                confirm against callers.

        Returns:
            The normalized tensor.
        """
        y = (x - self.norm_mean) / self.norm_std
        # Cast only after normalizing so the arithmetic runs at the buffers'
        # (float32) precision or higher.
        if self.dtype is not None:
            y = y.to(self.dtype)
        return y
def get_default_conditioner():
    """Build an ``InputConditioner`` from timm's OpenAI CLIP statistics.

    ``timm`` is imported inside the function, so it is only needed when this
    factory is actually called.

    Returns:
        An ``InputConditioner`` with ``input_scale=1.0`` and no output dtype
        cast.
    """
    from timm.data.constants import (
        OPENAI_CLIP_MEAN as clip_mean,
        OPENAI_CLIP_STD as clip_std,
    )

    # input_scale=1.0 stores the statistics unscaled, so inputs must already be
    # on the same scale as timm's constants (presumably [0, 1] — confirm).
    conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=clip_mean,
        norm_std=clip_std,
    )
    return conditioner
def _to_tensor(v: norm_t) -> torch.Tensor:
    """Convert per-channel statistics to a float32 tensor of shape ``(C, 1, 1)``.

    Args:
        v: A sequence of floats or a tensor holding one value per channel.

    Returns:
        A float32 tensor whose values are flattened along dim 0, followed by
        two singleton dims, so it broadcasts against tensors whose last three
        dims are ``(C, H, W)``.
    """
    # ``reshape`` (unlike ``view``) also works on non-contiguous tensors, e.g.
    # transposed ones, and returns a view whenever one is possible. Note that
    # ``as_tensor`` may share memory with ``v`` when dtype/device already match.
    return torch.as_tensor(v, dtype=torch.float32).reshape(-1, 1, 1)