from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # When True, faster attention kernels (SDPA / FlashAttention-2) are requested.
        self.optimized = optimized


class MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = nn.SiLU()(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MultiheadAttentionPoolingHead(nn.Module):
    """Attention pooling over token embeddings, followed by a projection
    into the vision embedding space."""

    def __init__(
        self,
        hidden_size: int,
        out_hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        intermediate_size: int,
    ):
        super().__init__()
        # Learnable query ("probe") that attends over all token embeddings.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            hidden_size, num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(hidden_size, intermediate_size)
        # Maps the pooled representation to the output (vision) hidden size.
        self.projector = nn.Linear(hidden_size, out_hidden_size)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        hidden_state = self.projector(hidden_state)

        # The probe is a single query, so take the single pooled token.
        return hidden_state[:, 0]


class MexmaSigLIP(PreTrainedModel):
    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config

        # Text tower: MEXMA (XLM-RoBERTa) without the default pooling layer.
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Pools text hidden states (1024-dim) and projects them to the SigLIP space (1152-dim).
        self.text_projector = MultiheadAttentionPoolingHead(1024, 1152, 16, 1e-5, 4304)

        # Vision tower: SigLIP 2 so400m at 512px resolution.
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip2-so400m-patch16-512"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
            vision_config.torch_dtype = "bfloat16"
        self.vision_model = SiglipVisionModel(vision_config).vision_model

        # SigLIP-style learnable temperature (initialised to ln(1/0.07)) and bias (-10).
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        # Pairwise image-text similarities, scaled and shifted as in SigLIP.
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
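

# --- Usage sketch (not part of the model definition) ---
# A minimal example of how this model might be exercised end to end. The choice of
# tokenizer ("facebook/MEXMA" via AutoTokenizer) and image processor
# (SiglipImageProcessor from "google/siglip2-so400m-patch16-512") is an assumption,
# not something this file specifies, as is the dummy image used for pixel values.
if __name__ == "__main__":
    from transformers import AutoTokenizer, SiglipImageProcessor

    config = MexmaSigLIPConfig(optimized=False)
    model = MexmaSigLIP(config).eval()

    # Assumed preprocessors; swap in whatever the released checkpoint documents.
    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    processor = SiglipImageProcessor.from_pretrained("google/siglip2-so400m-patch16-512")

    texts = ["a photo of a cat", "a photo of a dog"]
    text_inputs = tokenizer(texts, padding=True, return_tensors="pt")

    # Dummy RGB image; in practice this would be a PIL.Image loaded from disk.
    image = Image.new("RGB", (512, 512))
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            text_inputs.input_ids, text_inputs.attention_mask, pixel_values
        )
    print(image_logits.shape)  # (num_images, num_texts)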