visheratin committed
Commit f01b1a7 · verified · 1 Parent(s): 485e043

Upload folder using huggingface_hub

Files changed (2):
  1. config.json +4 -0
  2. mexma_siglip.py +126 -0
config.json CHANGED
@@ -1,4 +1,8 @@
  {
+   "auto_map": {
+     "AutoConfig": "mexma_siglip.MexmaSigLIPConfig",
+     "AutoModel": "mexma_siglip.MexmaSigLIP"
+   },
    "architectures": [
      "MexmaSigLIP"
    ],
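
The new `auto_map` entries tell transformers which classes in mexma_siglip.py back AutoConfig and AutoModel, so the checkpoint can be loaded straight from the Hub with remote code enabled. A minimal sketch (the repository id "visheratin/mexma-siglip2" is an assumption for illustration):

from transformers import AutoModel

# auto_map routes AutoModel to mexma_siglip.MexmaSigLIP in this repo;
# trust_remote_code=True is required to execute the custom modeling file.
model = AutoModel.from_pretrained(
    "visheratin/mexma-siglip2",  # assumed repo id
    trust_remote_code=True,
)
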
mexma_siglip.py ADDED
@@ -0,0 +1,126 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import (
+    PretrainedConfig,
+    PreTrainedModel,
+    SiglipVisionConfig,
+    SiglipVisionModel,
+    XLMRobertaConfig,
+    XLMRobertaModel,
+)
+
+
+class MexmaSigLIPConfig(PretrainedConfig):
+    """Configuration for MexmaSigLIP; `optimized` switches the encoders to
+    faster attention implementations (SDPA / FlashAttention-2)."""
+
+    def __init__(
+        self,
+        optimized: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.optimized = optimized
+
+
+class MLP(nn.Module):
+    """Two-layer feed-forward block with a SiLU activation."""
+
+    def __init__(self, hidden_size: int, intermediate_size: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, intermediate_size)
+        self.fc2 = nn.Linear(intermediate_size, hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class MultiheadAttentionPoolingHead(nn.Module):
+    """Pools a token sequence into a single vector using a learned probe as
+    the attention query, then projects it to the shared embedding size."""
+
+    def __init__(self, hidden_size: int, out_hidden_size: int, num_attention_heads: int, layer_norm_eps: float, intermediate_size: int):
+        super().__init__()
+        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
+        self.attention = torch.nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
+        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.mlp = MLP(hidden_size, intermediate_size)
+        self.projector = nn.Linear(hidden_size, out_hidden_size)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        # One probe query per batch element attends over all tokens.
+        probe = self.probe.repeat(batch_size, 1, 1)
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+        hidden_state = self.projector(hidden_state)
+        return hidden_state[:, 0]
+
+
+class MexmaSigLIP(PreTrainedModel):
+    """Pairs the MEXMA (XLM-RoBERTa) text encoder with a SigLIP2 vision
+    tower in a shared embedding space, scored SigLIP-style with a learned
+    logit scale and bias."""
+
+    config_class = MexmaSigLIPConfig
+
+    def __init__(self, config: MexmaSigLIPConfig):
+        super().__init__(config)
+        self.config = config
+        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
+        if self.config.optimized:
+            text_config._attn_implementation = "sdpa"
+        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
+        # Project 1024-d MEXMA states into the 1152-d SigLIP embedding space.
+        self.text_projector = MultiheadAttentionPoolingHead(1024, 1152, 16, 1e-5, 4304)
+        vision_config = SiglipVisionConfig.from_pretrained(
+            "google/siglip2-so400m-patch16-512"
+        )
+        if self.config.optimized:
+            vision_config._attn_implementation = "flash_attention_2"
+            vision_config.torch_dtype = "bfloat16"
+        self.vision_model = SiglipVisionModel(vision_config).vision_model
+        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)
+
+    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
+        text_features = self.encode_texts(input_ids, attention_mask, normalize)
+        image_features = self.encode_images(image_inputs, normalize)
+        return {
+            "image_features": image_features,
+            "text_features": text_features,
+            "logit_scale": self.logit_scale,
+            "logit_bias": self.logit_bias,
+        }
+
+    def encode_images(
+        self,
+        pixel_values,
+        normalize=False,
+    ):
+        features = self.vision_model(pixel_values).pooler_output
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def encode_texts(
+        self,
+        input_ids,
+        attention_mask,
+        normalize=False,
+    ):
+        features = self.text_model(
+            input_ids=input_ids, attention_mask=attention_mask
+        ).last_hidden_state
+        features = self.text_projector(features)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def get_logits(
+        self,
+        input_ids,
+        attention_mask,
+        pixel_values,
+    ):
+        image_features = self.encode_images(pixel_values, normalize=True)
+        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
+        image_logits = (
+            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
+        )
+        text_logits = image_logits.T
+        return image_logits, text_logits
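
For reference, a minimal inference sketch against this modeling code. The repo id and image path are assumptions for illustration; the tokenizer and image processor follow the sub-model checkpoints hard-coded in MexmaSigLIP.__init__ ("facebook/MEXMA" and "google/siglip2-so400m-patch16-512"):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "visheratin/mexma-siglip2", trust_remote_code=True  # assumed repo id
).eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
processor = AutoImageProcessor.from_pretrained("google/siglip2-so400m-patch16-512")

image = Image.open("cat.jpg")  # placeholder image path
pixel_values = processor(images=image, return_tensors="pt").pixel_values
texts = tokenizer(
    ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
)

with torch.no_grad():
    image_logits, _ = model.get_logits(
        input_ids=texts.input_ids,
        attention_mask=texts.attention_mask,
        pixel_values=pixel_values,
    )

# SigLIP scores each image-text pair with an independent sigmoid over
# logit_scale.exp() * cosine_similarity + logit_bias, not a softmax.
probs = torch.sigmoid(image_logits)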