michaeldinzinger committed on
Commit
ec964c6
·
1 Parent(s): 356c916

Add modeling

Browse files
Files changed (1) hide show
  1. modeling_arctic_m_bge_small.py +77 -0
modeling_arctic_m_bge_small.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import BertModel, PreTrainedModel, BertConfig, PretrainedConfig, AutoModel
5
+ from typing import *
6
+
7
+
8
class ConcatModelConfig(PretrainedConfig):
    """Configuration class paired with ``ConcatModel``.

    It declares no fields of its own: it only sets the ``model_type``
    identifier and hands every keyword argument to ``PretrainedConfig``.
    """

    model_type = "arctic-m-bge-small"

    def __init__(self, **kwargs):
        """Pass all keyword arguments through, unchanged, to the base config."""
        super().__init__(**kwargs)
13
+
14
+
15
+ # See https://huggingface.co/Marqo/marqo-chimera-arctic-bge-m
16
class ConcatModel(PreTrainedModel):
    """Encode one input with several encoders and concatenate the embeddings.

    Each sub-model stored in ``self.model`` receives the same inputs. From
    each one, the vector at position 0 of its first output is L2-normalized,
    and the normalized vectors are concatenated along the last dimension.
    """

    config_class = ConcatModelConfig

    def __init__(self, config: ConcatModelConfig):
        """Build two randomly initialised BERT encoders as placeholders.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
        """
        super().__init__(config)
        # NOTE(review): both sub-models share this single BERT-base-sized
        # config (768 hidden, 12 layers). The model_type name suggests one of
        # the encoders may be a smaller model -- confirm these shapes match
        # the checkpoint that is actually loaded into this class.
        bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
        )

        self.model = nn.ModuleDict(
            {
                "model_0": BertModel(bert_config),
                "model_1": BertModel(bert_config),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the concatenation of every sub-model's normalized embedding.

        Args:
            input_ids: Token ids, passed unchanged to every sub-model.
            attention_mask: Attention mask, passed unchanged to every sub-model.
            token_type_ids: Optional segment ids, passed unchanged to every
                sub-model (``None`` is forwarded as-is).

        Returns:
            The per-model embeddings joined along the last dimension, in the
            iteration order of ``self.model``.
        """
        embeddings = []
        # Only the modules are needed, so iterate values rather than items.
        for model in self.model.values():
            model_output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # Position 0 along the second axis of the first output; for a
            # BERT encoder this is presumably the [CLS] hidden state
            # (assumes the output is (batch, seq, hidden) -- TODO confirm).
            pooled_output = model_output[0][:, 0]
            # L2-normalize each embedding before concatenation.
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)

        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the sub-models with pretrained models loaded via ``AutoModel``.

        The loaded models are put in eval mode and stored under the keys
        ``model_0``, ``model_1``, ... in the order given by ``in_models``.

        Args:
            in_models: Model names or paths handed to
                ``AutoModel.from_pretrained``.
            has_pooling_layer: One flag per entry of ``in_models``, forwarded
                as ``add_pooling_layer``. Surplus flags are ignored.

        Raises:
            ValueError: If ``has_pooling_layer`` has fewer entries than
                ``in_models``.
        """
        # Fail fast, before any (potentially slow) model download, instead of
        # hitting an IndexError part-way through the loop.
        if len(has_pooling_layer) < len(in_models):
            raise ValueError(
                "has_pooling_layer must provide one flag per model: got "
                f"{len(has_pooling_layer)} flag(s) for {len(in_models)} model(s)"
            )

        loaded = {}
        for i, (model_name, add_pooling) in enumerate(
            zip(in_models, has_pooling_layer)
        ):
            # SECURITY: trust_remote_code=True lets the referenced repository
            # execute arbitrary Python on this machine. Only pass model names
            # from sources you trust.
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=add_pooling,
                trust_remote_code=True,
            )
            model.eval()
            loaded[f"model_{i}"] = model

        self.model = nn.ModuleDict(loaded)