from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
import timm
import torch.nn as nn
import torch
import numpy
from torchvision import transforms
from PIL import Image

class RenameLayerScale(nn.Module):
    """timm's LayerScale with the learnable scale registered as ``weight`` instead of ``gamma``."""

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.weight) if self.inplace else x * self.weight

# Patch timm so every ViT created below uses the renamed LayerScale; timm's
# transformer Block resolves LayerScale from this module attribute at
# construction time, so the patch applies to subsequently built models.
timm.models.vision_transformer.LayerScale = RenameLayerScale
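
# Why the rename matters (illustrative note, not from the original file):
# Hugging Face's checkpoint-loading code historically rewrites state-dict keys
# containing "gamma" to "weight" (a TensorFlow-era compatibility shim), which
# would mangle timm's stock LayerScale, whose parameter is named "gamma". The
# patched class above already uses the safe name:
#
#     >>> list(RenameLayerScale(dim=8).state_dict().keys())
#     ['weight']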

class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,   # kwargs for the timm vision encoder
        text_config=None,     # kwargs for the BERT text encoder
        projection_dim=768,   # dimensionality of the shared embedding space
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim
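

# A sketch of the vision_config payload KEEPModel expects below. These exact
# values are assumptions chosen for illustration; a released checkpoint ships
# its own config.json.
_EXAMPLE_VISION_CONFIG = {
    "img_size": 224,      # input resolution
    "patch_size": 16,     # ViT-L/16
    "init_values": 1e-5,  # LayerScale initialization
    "num_classes": 0,     # 0 -> timm returns pooled features, no classifier head
}
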

class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision encoder: a timm ViT-L/16 built from vision_config.
        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],  # 0 -> feature extractor only
        )

        # Two-layer MLP projecting ViT features into the shared embedding space.
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text encoder: a BERT model built from text_config.
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # Learnable temperature for CLIP-style contrastive logits,
        # initialized to log(1 / 0.04), i.e. a scale of 25 after exp().
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        vision_features = self.visual(image_inputs)  # [batch_size, vision_dim]
        # Project into the shared space and L2-normalize: [batch_size, projection_dim]
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )
        return vision_features
    
    def encode_text(self, text_inputs):
        # Take BERT's [CLS] pooler output, then L2-normalize: [batch_size, text_dim]
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )
        return text_features
    
    
    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
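

if __name__ == "__main__":
    # Smoke test: a minimal sketch under assumed config values (random weights,
    # dummy inputs), not the authors' official usage.
    config = KEEPConfig(
        vision_config=_EXAMPLE_VISION_CONFIG,
        text_config={},  # BertConfig defaults: bert-base sized encoder, vocab 30522
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    image = torch.randn(1, 3, 224, 224)  # one dummy RGB image
    text = {
        "input_ids": torch.randint(0, 30522, (1, 16)),  # random token ids
        "attention_mask": torch.ones(1, 16, dtype=torch.long),
    }
    with torch.no_grad():
        out = model(image, text)

    # Both feature sets are L2-normalized, so CLIP-style image-text logits are
    # just the temperature-scaled dot product.
    logits = model.logit_scale.exp() * out["vision_features"] @ out["text_features"].t()
    print(out["vision_features"].shape, out["text_features"].shape, logits.shape)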