from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
import timm
import torch.nn as nn
import torch
import numpy
from torchvision import transforms
from PIL import Image

class RenameLayerScale(nn.Module):
    """LayerScale that stores its scale parameter as `weight` rather than
    timm's default `gamma`, keeping parameter names compatible with
    Hugging Face checkpoint loading."""

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.weight) if self.inplace else x * self.weight


# Make every timm ViT block use the renamed LayerScale defined above.
timm.models.vision_transformer.LayerScale = RenameLayerScale

class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,   # kwargs for the timm vision encoder
        text_config=None,     # kwargs for the BERT text encoder
        projection_dim=768,   # dimension of the shared embedding space
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim

class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision encoder: a timm ViT-L/16 configured from vision_config.
        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],
        )
        # Projection head mapping vision features into the shared embedding space.
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text encoder: BERT built from text_config.
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # Learnable logit scale (temperature), initialized to log(1 / 0.04).
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        vision_features = self.visual(image_inputs)  # [batch_size, vision_dim]
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )  # [batch_size, projection_dim]
        return vision_features

    def encode_text(self, text_inputs):
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )  # [batch_size, text_dim]
        return text_features

    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
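

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): builds a small random-weight
# KEEP model and runs a dummy forward pass. The config values below are
# illustrative assumptions only; the real checkpoint's config.json defines
# the actual vision_config / text_config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = KEEPConfig(
        vision_config={
            "img_size": 224,
            "patch_size": 16,
            "init_values": 1e-5,
            "num_classes": 0,  # 0 -> timm returns pooled features, not logits
        },
        text_config={
            "vocab_size": 30522,
            "hidden_size": 768,
            "num_hidden_layers": 2,   # kept small for a quick sanity check
            "num_attention_heads": 12,
        },
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    # Dummy inputs: a batch of 2 images and 2 tokenized sentences.
    images = torch.randn(2, 3, 224, 224)
    texts = {
        "input_ids": torch.randint(0, 30522, (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    }
    with torch.no_grad():
        out = model(images, texts)
    print(out["vision_features"].shape)  # torch.Size([2, 768])
    print(out["text_features"].shape)    # torch.Size([2, 768])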