from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig

import timm
import torch
import torch.nn as nn
import numpy
from torchvision import transforms
from PIL import Image
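

# Note (assumption based on the class name below): timm's LayerScale stores its
# scale parameter in an attribute named `gamma`, and `PreTrainedModel.from_pretrained`
# remaps state-dict keys containing "gamma"/"beta" to "weight"/"bias" while loading.
# This drop-in replacement, patched into timm below, names the parameter `weight`
# so checkpoint keys stay consistent with what the Hugging Face loader expects.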
class RenameLayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.weight) if self.inplace else x * self.weight


timm.models.vision_transformer.LayerScale = RenameLayerScale


class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        projection_dim=768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim
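

# KEEPModel pairs a timm ViT-L/16 vision tower (followed by an MLP projection head)
# with a BERT text tower, CLIP-style: both branches emit L2-normalized embeddings of
# size `projection_dim`, and `logit_scale` is a learned inverse temperature
# (initialized to 1/0.04) for scaling image-text similarities.
# `config.vision_config` is expected to provide "img_size", "patch_size",
# "init_values", and "num_classes"; `config.text_config` is unpacked into a BertConfig.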
class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],
        )

        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        vision_features = self.visual(image_inputs)
        vision_features = torch.nn.functional.normalize(self.visual_head(vision_features), dim=-1)
        return vision_features

    def encode_text(self, text_inputs):
        text_features = torch.nn.functional.normalize(self.text(**text_inputs).pooler_output, dim=-1)
        return text_features

    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
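

# Usage sketch (not part of the model definition): a minimal, self-contained example
# of how KEEPModel might be exercised. The config values below (image size 224,
# patch 16, init_values 1e-5, the BERT text settings) and the preprocessing stats are
# illustrative assumptions, not necessarily the released KEEP settings; a real
# checkpoint would be loaded via KEEPModel.from_pretrained with its own
# tokenizer and image transform.
if __name__ == "__main__":
    config = KEEPConfig(
        vision_config={"img_size": 224, "patch_size": 16, "init_values": 1e-5, "num_classes": 0},
        text_config={
            "vocab_size": 30522,
            "hidden_size": 768,
            "num_hidden_layers": 2,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
        },
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    # Dummy image: a blank PIL image pushed through a plain resize/normalize pipeline.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    image = Image.new("RGB", (256, 256))
    image_inputs = preprocess(image).unsqueeze(0)

    # Dummy text batch in the dict form BertModel expects (normally built by a tokenizer).
    text_inputs = {
        "input_ids": torch.randint(0, config.text_config["vocab_size"], (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    }

    with torch.no_grad():
        outputs = model(image_inputs, text_inputs)
        # CLIP-style image-to-text similarity, scaled by the learned temperature.
        logits = model.logit_scale.exp() * outputs["vision_features"] @ outputs["text_features"].t()
    print(logits.shape)  # torch.Size([1, 2])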