visheratin commited on
Commit
bee271e
·
verified ·
1 Parent(s): 1924a68

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +10 -4
nllb_mrl.py CHANGED
@@ -15,12 +15,14 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
15
  clip_model_name: str = "",
16
  target_resolution: int = -1,
17
  mrl_resolutions: List[int] = [],
 
18
  **kwargs,
19
  ):
20
  super().__init__(**kwargs)
21
  self.clip_model_name = clip_model_name
22
  self.target_resolution = target_resolution
23
  self.mrl_resolutions = mrl_resolutions
 
24
 
25
 
26
  class MatryoshkaLayer(nn.Module):
@@ -50,10 +52,14 @@ class MatryoshkaNllbClip(PreTrainedModel):
50
  if isinstance(device, str):
51
  device = torch.device(device)
52
  self.config = config
53
- self.model = create_model(
54
- config.clip_model_name, output_dict=True
 
 
 
 
 
55
  )
56
- pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
57
  self.transform = image_transform_v2(
58
  pp_cfg,
59
  is_train=False,
@@ -106,7 +112,7 @@ class MatryoshkaNllbClip(PreTrainedModel):
106
  )
107
  features = self.matryoshka_layer.layers[str(resolution)](features)
108
  return F.normalize(features, dim=-1) if normalize else features
109
-
110
  def encode_text(
111
  self,
112
  text,
 
15
  clip_model_name: str = "",
16
  target_resolution: int = -1,
17
  mrl_resolutions: List[int] = [],
18
+ preprocess_cfg: Union[dict, None] = None,
19
  **kwargs,
20
  ):
21
  super().__init__(**kwargs)
22
  self.clip_model_name = clip_model_name
23
  self.target_resolution = target_resolution
24
  self.mrl_resolutions = mrl_resolutions
25
+ self.preprocess_cfg = preprocess_cfg
26
 
27
 
28
  class MatryoshkaLayer(nn.Module):
 
52
  if isinstance(device, str):
53
  device = torch.device(device)
54
  self.config = config
55
+ self.model = create_model(config.clip_model_name, output_dict=True)
56
+ pp_cfg = PreprocessCfg(
57
+ size=config.preprocess_cfg["size"],
58
+ mean=config.preprocess_cfg["mean"],
59
+ std=config.preprocess_cfg["std"],
60
+ interpolation=config.preprocess_cfg["interpolation"],
61
+ resize_mode=config.preprocess_cfg["resize_mode"],
62
  )
 
63
  self.transform = image_transform_v2(
64
  pp_cfg,
65
  is_train=False,
 
112
  )
113
  features = self.matryoshka_layer.layers[str(resolution)](features)
114
  return F.normalize(features, dim=-1) if normalize else features
115
+
116
  def encode_text(
117
  self,
118
  text,