# KEEP / modeling_keep.py
import numpy
import timm
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PretrainedConfig, PreTrainedModel
class RenameLayerScale(nn.Module):
    """Drop-in replacement for timm's LayerScale.

    timm stores the scale parameter as ``gamma``, but Hugging Face's
    ``PreTrainedModel`` checkpoint loading renames state-dict keys containing
    "gamma" to "weight". Naming the parameter ``weight`` up front keeps the
    saved and loaded keys consistent.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.weight) if self.inplace else x * self.weight


# Monkey-patch timm so every ViT created below uses the renamed LayerScale.
timm.models.vision_transformer.LayerScale = RenameLayerScale
class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,  # kwargs for the timm ViT vision encoder
        text_config=None,    # kwargs for the BERT text encoder
        projection_dim=768,  # dimensionality of the shared embedding space
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim
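

# Illustrative only: KEEPModel.__init__ reads img_size, patch_size,
# init_values, and num_classes from vision_config, and passes text_config
# straight to BertConfig. The concrete values below are assumptions for the
# sketch, not the released checkpoint's configuration.
#
# config = KEEPConfig(
#     vision_config={
#         "img_size": 224,
#         "patch_size": 16,
#         "init_values": 1e-5,
#         "num_classes": 0,  # 0 makes timm return pooled features, no head
#     },
#     text_config={"vocab_size": 30522, "hidden_size": 768},
#     projection_dim=768,
# )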
class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision encoder: a timm ViT-L/16 backbone built from scratch
        # (weights come from the checkpoint, not from timm's pretrained hub).
        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],
        )
        # Projection head mapping ViT features into the shared embedding space.
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text encoder: a BERT model configured from the stored dict.
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # CLIP-style learnable temperature; exp(logit_scale) starts at 1 / 0.04 = 25.
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))
    def encode_image(self, image_inputs):
        vision_features = self.visual(image_inputs)  # [batch_size, vision_dim]
        # Project into the shared space and L2-normalize.
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )  # [batch_size, projection_dim]
        return vision_features

    def encode_text(self, text_inputs):
        # Use BERT's pooler output as the sentence embedding, L2-normalized.
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )  # [batch_size, text_dim]
        return text_features

    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
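

# Minimal usage sketch, not part of the released API. It assumes the config
# values from the example above and a BERT-compatible tokenizer; the tokenizer
# name and the random image tensor are placeholders for illustration.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = KEEPConfig(
        vision_config={
            "img_size": 224,
            "patch_size": 16,
            "init_values": 1e-5,
            "num_classes": 0,
        },
        text_config={"vocab_size": 30522, "hidden_size": 768},
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumption
    texts = ["a description of image one", "a description of image two"]
    text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
    images = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images

    with torch.no_grad():
        out = model(images, text_inputs)
        # CLIP-style similarity logits, scaled by the learnable temperature.
        logits = (
            model.logit_scale.exp()
            * out["vision_features"] @ out["text_features"].T
        )
    print(logits.shape)  # torch.Size([2, 2])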