visheratin
commited on
Update nllb_mrl.py
Browse files- nllb_mrl.py +10 -4
nllb_mrl.py
CHANGED
@@ -15,12 +15,14 @@ class MatryoshkaNllbClipConfig(PretrainedConfig):
|
|
15 |
clip_model_name: str = "",
|
16 |
target_resolution: int = -1,
|
17 |
mrl_resolutions: List[int] = [],
|
|
|
18 |
**kwargs,
|
19 |
):
|
20 |
super().__init__(**kwargs)
|
21 |
self.clip_model_name = clip_model_name
|
22 |
self.target_resolution = target_resolution
|
23 |
self.mrl_resolutions = mrl_resolutions
|
|
|
24 |
|
25 |
|
26 |
class MatryoshkaLayer(nn.Module):
|
@@ -50,10 +52,14 @@ class MatryoshkaNllbClip(PreTrainedModel):
|
|
50 |
if isinstance(device, str):
|
51 |
device = torch.device(device)
|
52 |
self.config = config
|
53 |
-
self.model = create_model(
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
-
pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
|
57 |
self.transform = image_transform_v2(
|
58 |
pp_cfg,
|
59 |
is_train=False,
|
@@ -106,7 +112,7 @@ class MatryoshkaNllbClip(PreTrainedModel):
|
|
106 |
)
|
107 |
features = self.matryoshka_layer.layers[str(resolution)](features)
|
108 |
return F.normalize(features, dim=-1) if normalize else features
|
109 |
-
|
110 |
def encode_text(
|
111 |
self,
|
112 |
text,
|
|
|
15 |
clip_model_name: str = "",
|
16 |
target_resolution: int = -1,
|
17 |
mrl_resolutions: List[int] = [],
|
18 |
+
preprocess_cfg: Union[dict, None] = None,
|
19 |
**kwargs,
|
20 |
):
|
21 |
super().__init__(**kwargs)
|
22 |
self.clip_model_name = clip_model_name
|
23 |
self.target_resolution = target_resolution
|
24 |
self.mrl_resolutions = mrl_resolutions
|
25 |
+
self.preprocess_cfg = preprocess_cfg
|
26 |
|
27 |
|
28 |
class MatryoshkaLayer(nn.Module):
|
|
|
52 |
if isinstance(device, str):
|
53 |
device = torch.device(device)
|
54 |
self.config = config
|
55 |
+
self.model = create_model(config.clip_model_name, output_dict=True)
|
56 |
+
pp_cfg = PreprocessCfg(
|
57 |
+
size=config.preprocess_cfg["size"],
|
58 |
+
mean=config.preprocess_cfg["mean"],
|
59 |
+
std=config.preprocess_cfg["std"],
|
60 |
+
interpolation=config.preprocess_cfg["interpolation"],
|
61 |
+
resize_mode=config.preprocess_cfg["resize_mode"],
|
62 |
)
|
|
|
63 |
self.transform = image_transform_v2(
|
64 |
pp_cfg,
|
65 |
is_train=False,
|
|
|
112 |
)
|
113 |
features = self.matryoshka_layer.layers[str(resolution)](features)
|
114 |
return F.normalize(features, dim=-1) if normalize else features
|
115 |
+
|
116 |
def encode_text(
|
117 |
self,
|
118 |
text,
|