Astaxanthin commited on
Commit
d5e7cfd
·
verified ·
1 Parent(s): 55a05ca

Upload modeling_keep.py

Browse files
Files changed (1) hide show
  1. modeling_keep.py +75 -0
modeling_keep.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
2
+ import timm
3
+ import torch.nn as nn
4
+ import torch
5
+ import numpy
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+
9
class KEEPConfig(PretrainedConfig):
    """Configuration for the KEEP vision-language model.

    Stores the sub-configurations for the vision and text encoders together
    with the dimensionality of the shared projection space.
    """

    # Architecture tag used by transformers' auto-class machinery.
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,   # configuration of the vision encoder
        text_config=None,     # configuration of the text encoder (BERT-style dict)
        projection_dim=768,   # size of the joint embedding space (default 768)
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim
23
+
24
+
25
class KEEPModel(PreTrainedModel):
    """KEEP: a CLIP-style dual encoder pairing a timm ViT-L/16 vision tower
    with a BERT text tower (e.g. PubMedBERT), projected into a shared
    embedding space and L2-normalized for contrastive matching.
    """

    config_class = KEEPConfig  # bind this model to its custom config class

    def __init__(self, config):
        super().__init__(config)

        # Vision encoder: ViT-Large/16 from timm. num_classes=0 strips the
        # classification head so the model returns pooled features.
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=224,
            patch_size=16,
            init_values=1e-5,        # LayerScale initial value
            num_classes=0,           # feature extraction, no classifier
            dynamic_img_size=True,   # allow input resolutions other than 224
        )

        # MLP head projecting vision features to the joint space (projection_dim).
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text encoder built from the supplied BERT configuration.
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # Learnable temperature for contrastive logits, initialized to log(1/0.04).
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        """Encode images.

        Args:
            image_inputs: batched image tensor accepted by the timm ViT
                (presumably [batch, 3, H, W] — confirm against caller).

        Returns:
            L2-normalized embeddings of shape [batch_size, projection_dim].
        """
        vision_features = self.visual(image_inputs)  # [batch_size, vision_dim]
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )
        return vision_features

    def encode_text(self, text_inputs):
        """Encode tokenized text.

        Args:
            text_inputs: dict of tokenizer outputs (input_ids, attention_mask, ...)
                unpacked into the BERT forward call.

        Returns:
            L2-normalized embeddings from BERT's pooler output,
            shape [batch_size, hidden_size].
        """
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )
        return text_features

    def forward(self, image_inputs, text_inputs):
        """Encode both modalities and return their normalized features."""
        vision_features = self.encode_image(image_inputs)

        # BUG FIX: the original called self.encode_image(text_inputs); text
        # must be routed through the text tower.
        text_features = self.encode_text(text_inputs)

        # Return the two independent feature sets.
        return {
            "vision_features": vision_features,  # image embeddings
            "text_features": text_features,      # text embeddings
        }