sachin commited on
Commit
24d96ab
·
1 Parent(s): 3b13f40

Refactoring to distangle modules

Browse files
Files changed (5) hide show
  1. src/config.py +2 -14
  2. src/models.py +6 -5
  3. src/tokenizer.py +5 -3
  4. src/trainer.py +19 -1
  5. src/vision_model.py +14 -5
src/config.py CHANGED
@@ -65,7 +65,7 @@ class TinyCLIPConfig(PretrainedConfig):
65
  max_len: int = 128,
66
  cls_type: bool = True,
67
  freeze_vision_base: bool = False,
68
- freeze_text_base: bool = False,
69
  loss_type: str = "cyclip",
70
  **kwargs,
71
  ):
@@ -85,18 +85,6 @@ class TinyCLIPConfig(PretrainedConfig):
85
  super().__init__(**kwargs)
86
 
87
 
88
- class ModelConfig(pydantic.BaseModel):
89
- text_model: str = "microsoft/xtremedistil-l6-h256-uncased" # 51 mb
90
- vision_model: str = "edgenext_small" # 20 mb
91
- projection_layers: int = 3
92
- embed_dim: int = 256
93
- transformer_embed_dim: int = 768
94
- max_len: int = 128 # 77
95
- cls_type: bool = True
96
- freeze_vision_base: bool = False
97
- freeze_text_base: bool = False
98
-
99
-
100
  class TrainerConfig(pydantic.BaseModel):
101
  epochs: int = 20
102
  batch_size: int = 64
@@ -112,5 +100,5 @@ class TrainerConfig(pydantic.BaseModel):
112
 
113
  run_openai_clip: bool = False
114
 
115
- _model_config: ModelConfig = ModelConfig()
116
  _data_config: DataConfig = DataConfig()
 
65
  max_len: int = 128,
66
  cls_type: bool = True,
67
  freeze_vision_base: bool = False,
68
+ freeze_text_base: bool = True,
69
  loss_type: str = "cyclip",
70
  **kwargs,
71
  ):
 
85
  super().__init__(**kwargs)
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  class TrainerConfig(pydantic.BaseModel):
89
  epochs: int = 20
90
  batch_size: int = 64
 
100
 
101
  run_openai_clip: bool = False
102
 
103
+ _model_config: TinyCLIPConfig = TinyCLIPConfig()
104
  _data_config: DataConfig = DataConfig()
src/models.py CHANGED
@@ -1,14 +1,14 @@
1
  from PIL import Image
2
- import timm
3
- from timm import data
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
- import transformers
8
  from transformers import PreTrainedModel
9
 
10
  from src.config import TinyCLIPConfig, TinyCLIPTextConfig, TinyCLIPVisionConfig
11
  from src import loss
 
12
 
13
 
14
  class Projection(nn.Module):
@@ -70,9 +70,10 @@ class TinyCLIPVisionEncoder(PreTrainedModel):
70
 
71
  def __init__(self, config: TinyCLIPVisionConfig):
72
  super().__init__(config)
73
-
 
74
  self.projection = projection_layers(
75
- self.base.num_features, config.embed_dims, config.projection_layers
76
  )
77
 
78
  def forward(self, images: list[Image.Image]):
 
1
  from PIL import Image
2
+ import transformers
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+
7
  from transformers import PreTrainedModel
8
 
9
  from src.config import TinyCLIPConfig, TinyCLIPTextConfig, TinyCLIPVisionConfig
10
  from src import loss
11
+ from src import vision_model
12
 
13
 
14
  class Projection(nn.Module):
 
70
 
71
  def __init__(self, config: TinyCLIPVisionConfig):
72
  super().__init__(config)
73
+ base, num_features = vision_model.get_vision_base(config)
74
+ self.base = base
75
  self.projection = projection_layers(
76
+ num_features, config.embed_dims, config.projection_layers
77
  )
78
 
79
  def forward(self, images: list[Image.Image]):
src/tokenizer.py CHANGED
@@ -3,11 +3,13 @@ from typing import Union
3
  import torch
4
  from transformers import AutoTokenizer
5
 
 
 
6
 
7
  class Tokenizer:
8
- def __init__(self, model_name: str, max_len: int) -> None:
9
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- self.max_len = max_len
11
 
12
  def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
13
  return self.tokenizer(
 
3
  import torch
4
  from transformers import AutoTokenizer
5
 
6
+ from src.config import TinyCLIPTextConfig
7
+
8
 
9
  class Tokenizer:
10
+ def __init__(self, text_config: TinyCLIPTextConfig) -> None:
11
+ self.tokenizer = AutoTokenizer.from_pretrained(text_config.text_model)
12
+ self.max_len = text_config.max_len
13
 
14
  def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
15
  return self.tokenizer(
src/trainer.py CHANGED
@@ -1,7 +1,25 @@
1
  from src import data
2
  from src import config
3
  from src import vision_model
 
 
 
 
4
 
5
 
6
  def train(config: config.TrainerConfig):
7
- train_dl, valid_dl = data.get_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from src import data
2
  from src import config
3
  from src import vision_model
4
+ from src import tokenizer as tk
5
+ from src.lightning_module import LightningModule
6
+ from src import loss
7
+ from src import models
8
 
9
 
10
  def train(config: config.TrainerConfig):
11
+ transform = vision_model.get_vision_transform(config._model_config.vision_config)
12
+ tokenizer = tk.Tokenizer(config._model_config.text_config)
13
+ train_dl, valid_dl = data.get_dataset(
14
+ transform=transform, tokenizer=tokenizer, hyper_parameters=config # type: ignore
15
+ )
16
+ vision_encoder = models.TinyCLIPVisionEncoder(config=config._model_config.vision_config)
17
+ text_encoder = models.TinyCLIPTextEncoder(config=config._model_config.text_config)
18
+
19
+ lightning_module = LightningModule(
20
+ vision_encoder=vision_encoder,
21
+ text_encoder=text_encoder,
22
+ loss_fn=loss.get_loss(config._model_config.loss_type),
23
+ hyper_parameters=config,
24
+ len_train_dl=len(train_dl),
25
+ )
src/vision_model.py CHANGED
@@ -1,11 +1,20 @@
1
  import timm
2
  from timm import data
 
 
3
 
4
- from src import config
5
 
6
 
7
- def get_vision_base_and_transform(config: config.TrainerConfig):
8
- base = timm.create_model(config._model_config.vision_model, num_classes=0)
9
- timm_config = data.resolve_data_config({}, model=base)
 
 
 
 
 
 
 
10
  transform = data.transforms_factory.create_transform(**timm_config)
11
- return base, transform
 
1
  import timm
2
  from timm import data
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
 
6
+ from src.config import TinyCLIPVisionConfig
7
 
8
 
9
+ def get_vision_base(
10
+ config: TinyCLIPVisionConfig,
11
+ ) -> tuple[nn.Module, int]:
12
+ base = timm.create_model(config.vision_model, num_classes=0, pretrained=True)
13
+ num_features = base.num_features
14
+ return base, num_features
15
+
16
+
17
+ def get_vision_transform(config: TinyCLIPVisionConfig) -> transforms.Compose:
18
+ timm_config = data.resolve_data_config({}, model=config.vision_model)
19
  transform = data.transforms_factory.create_transform(**timm_config)
20
+ return transform # type: ignore