yuezih commited on
Commit
ca19ab4
·
1 Parent(s): 455badc
Files changed (46) hide show
  1. .gitattributes +2 -0
  2. README.md +1 -1
  3. SMILE/BLIP/CHANGELOG +24 -0
  4. SMILE/BLIP/__init__.py +0 -0
  5. SMILE/BLIP/configs/bert_config.json +21 -0
  6. SMILE/BLIP/configs/caption_coco.yaml +32 -0
  7. SMILE/BLIP/configs/med_config.json +21 -0
  8. SMILE/BLIP/configs/nlvr.yaml +21 -0
  9. SMILE/BLIP/configs/nocaps.yaml +15 -0
  10. SMILE/BLIP/configs/pretrain.yaml +27 -0
  11. SMILE/BLIP/configs/retrieval_coco.yaml +34 -0
  12. SMILE/BLIP/configs/retrieval_flickr.yaml +34 -0
  13. SMILE/BLIP/configs/retrieval_msrvtt.yaml +12 -0
  14. SMILE/BLIP/configs/vqa.yaml +25 -0
  15. SMILE/BLIP/data/__init__.py +101 -0
  16. SMILE/BLIP/data/coco_karpathy_dataset.py +126 -0
  17. SMILE/BLIP/data/flickr30k_dataset.py +93 -0
  18. SMILE/BLIP/data/nlvr_dataset.py +78 -0
  19. SMILE/BLIP/data/nocaps_dataset.py +32 -0
  20. SMILE/BLIP/data/pretrain_dataset.py +59 -0
  21. SMILE/BLIP/data/utils.py +112 -0
  22. SMILE/BLIP/data/video_dataset.py +110 -0
  23. SMILE/BLIP/data/vqa_dataset.py +88 -0
  24. SMILE/BLIP/demo.ipynb +0 -0
  25. SMILE/BLIP/models/__init__.py +0 -0
  26. SMILE/BLIP/models/blip.py +238 -0
  27. SMILE/BLIP/models/blip_vqa.py +186 -0
  28. SMILE/BLIP/models/med.py +955 -0
  29. SMILE/BLIP/models/model.py +211 -0
  30. SMILE/BLIP/models/vit.py +305 -0
  31. SMILE/BLIP/requirements.txt +4 -0
  32. SMILE/BLIP/scripts/eval.sh +9 -0
  33. SMILE/BLIP/scripts/train.sh +7 -0
  34. SMILE/BLIP/train_caption.py +221 -0
  35. SMILE/BLIP/transform/randaugment.py +340 -0
  36. SMILE/BLIP/utils.py +278 -0
  37. SMILE/LICENSE +40 -0
  38. SMILE/README.md +102 -0
  39. SMILE/__init__.py +0 -0
  40. app.py +70 -0
  41. example/COCO_val2014_000000001682.jpg +0 -0
  42. example/COCO_val2014_000000093534.jpg +0 -0
  43. example/COCO_val2014_000000411845.jpg +0 -0
  44. example/COCO_val2014_000000473133.jpg +0 -0
  45. example/COCO_val2014_000000562150.jpg +0 -0
  46. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model/blip_mle_smile_base.pth filter=lfs diff=lfs merge=lfs -text
37
+ model/blip_smile_base.pth filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: BLIP SMILE
3
- emoji: 🏢
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
 
1
  ---
2
  title: BLIP SMILE
3
+ emoji: 🌩
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
SMILE/BLIP/CHANGELOG ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Compared to the original BLIP code, the following changes have been made by Zihao Yue:
2
+
3
+ deleted: BLIP.gif
4
+ deleted: CODEOWNERS
5
+ deleted: CODE_OF_CONDUCT.md
6
+ deleted: LICENSE.txt
7
+ deleted: README.md
8
+ deleted: SECURITY.md
9
+ deleted: cog.yaml
10
+ deleted: eval_nocaps.py
11
+ deleted: eval_retrieval_video.py
12
+ deleted: models/blip_itm.py
13
+ deleted: models/blip_nlvr.py
14
+ deleted: models/blip_pretrain.py
15
+ deleted: models/blip_retrieval.py
16
+ deleted: models/nlvr_encoder.py
17
+ deleted: train_vqa.py
18
+ deleted: predict.py
19
+ deleted: pretrain.py
20
+ deleted: train_nlvr.py
21
+ deleted: train_retrieval.py
22
+ modified: train_caption.py
23
+ modified: configs/caption_coco.yaml
24
+ modified: demo.ipynb
SMILE/BLIP/__init__.py ADDED
File without changes
SMILE/BLIP/configs/bert_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
SMILE/BLIP/configs/caption_coco.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ coco_gt_root: 'annotation/coco_gt'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
7
+
8
+ # size of vit model; base or large
9
+ vit: 'base'
10
+ vit_grad_ckpt: False
11
+ vit_ckpt_layer: 0
12
+ batch_size: 32
13
+ init_lr: 1e-5
14
+
15
+ # vit: 'large'
16
+ # vit_grad_ckpt: True
17
+ # vit_ckpt_layer: 5
18
+ # batch_size: 16
19
+ # init_lr: 2e-6
20
+
21
+ image_size: 384
22
+
23
+ # generation configs
24
+ max_length: 75
25
+ min_length: 1
26
+ num_beams: 3
27
+ prompt: 'a picture of '
28
+
29
+ # optimizer
30
+ weight_decay: 0.05
31
+ min_lr: 0
32
+ max_epoch: 5
SMILE/BLIP/configs/med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
SMILE/BLIP/configs/nlvr.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/NLVR2/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
6
+
7
+ #size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size_train: 16
10
+ batch_size_test: 64
11
+ vit_grad_ckpt: False
12
+ vit_ckpt_layer: 0
13
+ max_epoch: 15
14
+
15
+ image_size: 384
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-5
20
+ min_lr: 0
21
+
SMILE/BLIP/configs/nocaps.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/nocaps/'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
6
+
7
+ vit: 'base'
8
+ batch_size: 32
9
+
10
+ image_size: 384
11
+
12
+ max_length: 20
13
+ min_length: 5
14
+ num_beams: 3
15
+ prompt: 'a picture of '
SMILE/BLIP/configs/pretrain.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
2
+ '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
3
+ ]
4
+ laion_path: ''
5
+
6
+ # size of vit model; base or large
7
+ vit: 'base'
8
+ vit_grad_ckpt: False
9
+ vit_ckpt_layer: 0
10
+
11
+ image_size: 224
12
+ batch_size: 75
13
+
14
+ queue_size: 57600
15
+ alpha: 0.4
16
+
17
+ # optimizer
18
+ weight_decay: 0.05
19
+ init_lr: 3e-4
20
+ min_lr: 1e-6
21
+ warmup_lr: 1e-6
22
+ lr_decay_rate: 0.9
23
+ max_epoch: 20
24
+ warmup_steps: 3000
25
+
26
+
27
+
SMILE/BLIP/configs/retrieval_coco.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/coco/images/'
2
+ ann_root: 'annotation'
3
+ dataset: 'coco'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 12
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 256
28
+ negative_all_rank: True
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
SMILE/BLIP/configs/retrieval_flickr.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/export/share/datasets/vision/flickr30k/'
2
+ ann_root: 'annotation'
3
+ dataset: 'flickr'
4
+
5
+ # set pretrained as a file path or an url
6
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
7
+
8
+ # size of vit model; base or large
9
+
10
+ vit: 'base'
11
+ batch_size_train: 32
12
+ batch_size_test: 64
13
+ vit_grad_ckpt: True
14
+ vit_ckpt_layer: 4
15
+ init_lr: 1e-5
16
+
17
+ # vit: 'large'
18
+ # batch_size_train: 16
19
+ # batch_size_test: 32
20
+ # vit_grad_ckpt: True
21
+ # vit_ckpt_layer: 10
22
+ # init_lr: 5e-6
23
+
24
+ image_size: 384
25
+ queue_size: 57600
26
+ alpha: 0.4
27
+ k_test: 128
28
+ negative_all_rank: False
29
+
30
+ # optimizer
31
+ weight_decay: 0.05
32
+ min_lr: 0
33
+ max_epoch: 6
34
+
SMILE/BLIP/configs/retrieval_msrvtt.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
2
+ ann_root: 'annotation'
3
+
4
+ # set pretrained as a file path or an url
5
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
6
+
7
+ # size of vit model; base or large
8
+ vit: 'base'
9
+ batch_size: 64
10
+ k_test: 128
11
+ image_size: 384
12
+ num_frm_test: 8
SMILE/BLIP/configs/vqa.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
2
+ vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
3
+ train_files: ['vqa_train','vqa_val','vg_qa']
4
+ ann_root: 'annotation'
5
+
6
+ # set pretrained as a file path or an url
7
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
8
+
9
+ # size of vit model; base or large
10
+ vit: 'base'
11
+ batch_size_train: 16
12
+ batch_size_test: 32
13
+ vit_grad_ckpt: False
14
+ vit_ckpt_layer: 0
15
+ init_lr: 2e-5
16
+
17
+ image_size: 480
18
+
19
+ k_test: 128
20
+ inference: 'rank'
21
+
22
+ # optimizer
23
+ weight_decay: 0.05
24
+ min_lr: 0
25
+ max_epoch: 10
SMILE/BLIP/data/__init__.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms
4
+ from torchvision.transforms.functional import InterpolationMode
5
+
6
+ from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
7
+ from data.nocaps_dataset import nocaps_eval
8
+ from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
9
+ from data.vqa_dataset import vqa_dataset
10
+ from data.nlvr_dataset import nlvr_dataset
11
+ from data.pretrain_dataset import pretrain_dataset
12
+ from transform.randaugment import RandomAugment
13
+
14
+ def create_dataset(dataset, config, min_scale=0.5):
15
+
16
+ normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
17
+
18
+ transform_train = transforms.Compose([
19
+ transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
20
+ transforms.RandomHorizontalFlip(),
21
+ RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
22
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
23
+ transforms.ToTensor(),
24
+ normalize,
25
+ ])
26
+ transform_test = transforms.Compose([
27
+ transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
28
+ transforms.ToTensor(),
29
+ normalize,
30
+ ])
31
+
32
+ if dataset=='pretrain':
33
+ dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
34
+ return dataset
35
+
36
+ elif dataset=='caption_coco':
37
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
38
+ val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
39
+ test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
40
+ return train_dataset, val_dataset, test_dataset
41
+
42
+ elif dataset=='nocaps':
43
+ val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
44
+ test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
45
+ return val_dataset, test_dataset
46
+
47
+ elif dataset=='retrieval_coco':
48
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
49
+ val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
50
+ test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
51
+ return train_dataset, val_dataset, test_dataset
52
+
53
+ elif dataset=='retrieval_flickr':
54
+ train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
55
+ val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
56
+ test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
57
+ return train_dataset, val_dataset, test_dataset
58
+
59
+ elif dataset=='vqa':
60
+ train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
61
+ train_files = config['train_files'], split='train')
62
+ test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
63
+ return train_dataset, test_dataset
64
+
65
+ elif dataset=='nlvr':
66
+ train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
67
+ val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
68
+ test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
69
+ return train_dataset, val_dataset, test_dataset
70
+
71
+
72
+ def create_sampler(datasets, shuffles, num_tasks, global_rank):
73
+ samplers = []
74
+ for dataset,shuffle in zip(datasets,shuffles):
75
+ sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
76
+ samplers.append(sampler)
77
+ return samplers
78
+
79
+
80
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
81
+ loaders = []
82
+ for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
83
+ if is_train:
84
+ shuffle = (sampler is None)
85
+ drop_last = True
86
+ else:
87
+ shuffle = False
88
+ drop_last = False
89
+ loader = DataLoader(
90
+ dataset,
91
+ batch_size=bs,
92
+ num_workers=n_worker,
93
+ pin_memory=True,
94
+ sampler=sampler,
95
+ shuffle=shuffle,
96
+ collate_fn=collate_fn,
97
+ drop_last=drop_last,
98
+ )
99
+ loaders.append(loader)
100
+ return loaders
101
+
SMILE/BLIP/data/coco_karpathy_dataset.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class coco_karpathy_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. coco/images/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
18
+ filename = 'coco_karpathy_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class coco_karpathy_caption_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. coco/images/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
61
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ def __len__(self):
70
+ return len(self.annotation)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.image_root,ann['image'])
77
+ image = Image.open(image_path).convert('RGB')
78
+ image = self.transform(image)
79
+
80
+ img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
81
+
82
+ return image, int(img_id)
83
+
84
+
85
+ class coco_karpathy_retrieval_eval(Dataset):
86
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
87
+ '''
88
+ image_root (string): Root directory of images (e.g. coco/images/)
89
+ ann_root (string): directory to store the annotation file
90
+ split (string): val or test
91
+ '''
92
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
93
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
94
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
95
+
96
+ download_url(urls[split],ann_root)
97
+
98
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
99
+ self.transform = transform
100
+ self.image_root = image_root
101
+
102
+ self.text = []
103
+ self.image = []
104
+ self.txt2img = {}
105
+ self.img2txt = {}
106
+
107
+ txt_id = 0
108
+ for img_id, ann in enumerate(self.annotation):
109
+ self.image.append(ann['image'])
110
+ self.img2txt[img_id] = []
111
+ for i, caption in enumerate(ann['caption']):
112
+ self.text.append(pre_caption(caption,max_words))
113
+ self.img2txt[img_id].append(txt_id)
114
+ self.txt2img[txt_id] = img_id
115
+ txt_id += 1
116
+
117
+ def __len__(self):
118
+ return len(self.annotation)
119
+
120
+ def __getitem__(self, index):
121
+
122
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
123
+ image = Image.open(image_path).convert('RGB')
124
+ image = self.transform(image)
125
+
126
+ return image, index
SMILE/BLIP/data/flickr30k_dataset.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class flickr30k_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. flickr30k/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
18
+ filename = 'flickr30k_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class flickr30k_retrieval_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. flickr30k/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
61
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ self.text = []
70
+ self.image = []
71
+ self.txt2img = {}
72
+ self.img2txt = {}
73
+
74
+ txt_id = 0
75
+ for img_id, ann in enumerate(self.annotation):
76
+ self.image.append(ann['image'])
77
+ self.img2txt[img_id] = []
78
+ for i, caption in enumerate(ann['caption']):
79
+ self.text.append(pre_caption(caption,max_words))
80
+ self.img2txt[img_id].append(txt_id)
81
+ self.txt2img[txt_id] = img_id
82
+ txt_id += 1
83
+
84
+ def __len__(self):
85
+ return len(self.annotation)
86
+
87
+ def __getitem__(self, index):
88
+
89
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
90
+ image = Image.open(image_path).convert('RGB')
91
+ image = self.transform(image)
92
+
93
+ return image, index
SMILE/BLIP/data/nlvr_dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+
5
+ from torch.utils.data import Dataset
6
+ from torchvision.datasets.utils import download_url
7
+
8
+ from PIL import Image
9
+
10
+ from data.utils import pre_caption
11
+
12
+ class nlvr_dataset(Dataset):
13
+ def __init__(self, transform, image_root, ann_root, split):
14
+ '''
15
+ image_root (string): Root directory of images
16
+ ann_root (string): directory to store the annotation file
17
+ split (string): train, val or test
18
+ '''
19
+ urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
20
+ 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
21
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
22
+ filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
23
+
24
+ download_url(urls[split],ann_root)
25
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
26
+
27
+ self.transform = transform
28
+ self.image_root = image_root
29
+
30
+
31
+ def __len__(self):
32
+ return len(self.annotation)
33
+
34
+
35
+ def __getitem__(self, index):
36
+
37
+ ann = self.annotation[index]
38
+
39
+ image0_path = os.path.join(self.image_root,ann['images'][0])
40
+ image0 = Image.open(image0_path).convert('RGB')
41
+ image0 = self.transform(image0)
42
+
43
+ image1_path = os.path.join(self.image_root,ann['images'][1])
44
+ image1 = Image.open(image1_path).convert('RGB')
45
+ image1 = self.transform(image1)
46
+
47
+ sentence = pre_caption(ann['sentence'], 40)
48
+
49
+ if ann['label']=='True':
50
+ label = 1
51
+ else:
52
+ label = 0
53
+
54
+ words = sentence.split(' ')
55
+
56
+ if 'left' not in words and 'right' not in words:
57
+ if random.random()<0.5:
58
+ return image0, image1, sentence, label
59
+ else:
60
+ return image1, image0, sentence, label
61
+ else:
62
+ if random.random()<0.5:
63
+ return image0, image1, sentence, label
64
+ else:
65
+ new_words = []
66
+ for word in words:
67
+ if word=='left':
68
+ new_words.append('right')
69
+ elif word=='right':
70
+ new_words.append('left')
71
+ else:
72
+ new_words.append(word)
73
+
74
+ sentence = ' '.join(new_words)
75
+ return image1, image0, sentence, label
76
+
77
+
78
+
SMILE/BLIP/data/nocaps_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ class nocaps_eval(Dataset):
10
+ def __init__(self, transform, image_root, ann_root, split):
11
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
12
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
13
+ filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
14
+
15
+ download_url(urls[split],ann_root)
16
+
17
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
18
+ self.transform = transform
19
+ self.image_root = image_root
20
+
21
+ def __len__(self):
22
+ return len(self.annotation)
23
+
24
+ def __getitem__(self, index):
25
+
26
+ ann = self.annotation[index]
27
+
28
+ image_path = os.path.join(self.image_root,ann['image'])
29
+ image = Image.open(image_path).convert('RGB')
30
+ image = self.transform(image)
31
+
32
+ return image, int(ann['img_id'])
SMILE/BLIP/data/pretrain_dataset.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ from torch.utils.data import Dataset
6
+
7
+ from PIL import Image
8
+ from PIL import ImageFile
9
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
10
+ Image.MAX_IMAGE_PIXELS = None
11
+
12
+ from data.utils import pre_caption
13
+ import os,glob
14
+
15
+ class pretrain_dataset(Dataset):
16
+ def __init__(self, ann_file, laion_path, transform):
17
+
18
+ self.ann_pretrain = []
19
+ for f in ann_file:
20
+ print('loading '+f)
21
+ ann = json.load(open(f,'r'))
22
+ self.ann_pretrain += ann
23
+
24
+ self.laion_path = laion_path
25
+ if self.laion_path:
26
+ self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
27
+
28
+ print('loading '+self.laion_files[0])
29
+ with open(self.laion_files[0],'r') as f:
30
+ self.ann_laion = json.load(f)
31
+
32
+ self.annotation = self.ann_pretrain + self.ann_laion
33
+ else:
34
+ self.annotation = self.ann_pretrain
35
+
36
+ self.transform = transform
37
+
38
+
39
+ def reload_laion(self, epoch):
40
+ n = epoch%len(self.laion_files)
41
+ print('loading '+self.laion_files[n])
42
+ with open(self.laion_files[n],'r') as f:
43
+ self.ann_laion = json.load(f)
44
+
45
+ self.annotation = self.ann_pretrain + self.ann_laion
46
+
47
+
48
+ def __len__(self):
49
+ return len(self.annotation)
50
+
51
+ def __getitem__(self, index):
52
+
53
+ ann = self.annotation[index]
54
+
55
+ image = Image.open(ann['image']).convert('RGB')
56
+ image = self.transform(image)
57
+ caption = pre_caption(ann['caption'],30)
58
+
59
+ return image, caption
SMILE/BLIP/data/utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ import utils
9
+
10
+ def pre_caption(caption,max_words=50):
11
+ caption = re.sub(
12
+ r"([.!\"()*#:;~])",
13
+ ' ',
14
+ caption.lower(),
15
+ )
16
+ caption = re.sub(
17
+ r"\s{2,}",
18
+ ' ',
19
+ caption,
20
+ )
21
+ caption = caption.rstrip('\n')
22
+ caption = caption.strip(' ')
23
+
24
+ #truncate caption
25
+ caption_words = caption.split(' ')
26
+ if len(caption_words)>max_words:
27
+ caption = ' '.join(caption_words[:max_words])
28
+
29
+ return caption
30
+
31
+ def pre_question(question,max_ques_words=50):
32
+ question = re.sub(
33
+ r"([.!\"()*#:;~])",
34
+ '',
35
+ question.lower(),
36
+ )
37
+ question = question.rstrip(' ')
38
+
39
+ #truncate question
40
+ question_words = question.split(' ')
41
+ if len(question_words)>max_ques_words:
42
+ question = ' '.join(question_words[:max_ques_words])
43
+
44
+ return question
45
+
46
+
47
+ def save_result(result, result_dir, filename, remove_duplicate=''):
48
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
49
+ final_result_file = os.path.join(result_dir, '%s.json'%filename)
50
+
51
+ json.dump(result,open(result_file,'w'))
52
+
53
+ dist.barrier()
54
+
55
+ if utils.is_main_process():
56
+ # combine results from all processes
57
+ result = []
58
+
59
+ for rank in range(utils.get_world_size()):
60
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
61
+ res = json.load(open(result_file,'r'))
62
+ result += res
63
+
64
+ if remove_duplicate:
65
+ result_new = []
66
+ id_list = []
67
+ for res in result:
68
+ if res[remove_duplicate] not in id_list:
69
+ id_list.append(res[remove_duplicate])
70
+ result_new.append(res)
71
+ result = result_new
72
+
73
+ json.dump(result,open(final_result_file,'w'))
74
+ print('result file saved to %s'%final_result_file)
75
+
76
+ return final_result_file
77
+
78
+
79
+
80
+ from pycocotools.coco import COCO
81
+ from pycocoevalcap.eval import COCOEvalCap
82
+ from torchvision.datasets.utils import download_url
83
+
84
+ def coco_caption_eval(coco_gt_root, results_file, split):
85
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
86
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
87
+ filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
88
+
89
+ download_url(urls[split],coco_gt_root)
90
+ annotation_file = os.path.join(coco_gt_root,filenames[split])
91
+
92
+ # create coco object and coco_result object
93
+ coco = COCO(annotation_file)
94
+ coco_result = coco.loadRes(results_file)
95
+
96
+ # create coco_eval object by taking coco and coco_result
97
+ coco_eval = COCOEvalCap(coco, coco_result)
98
+
99
+ # evaluate on a subset of images by setting
100
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
101
+ # please remove this line when evaluating the full validation set
102
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
103
+
104
+ # evaluate results
105
+ # SPICE will take a few minutes the first time, but speeds up due to caching
106
+ coco_eval.evaluate()
107
+
108
+ # print output evaluation scores
109
+ for metric, score in coco_eval.eval.items():
110
+ print(f'{metric}: {score:.3f}')
111
+
112
+ return coco_eval
SMILE/BLIP/data/video_dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from torchvision.datasets.utils import download_url
3
+
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import decord
9
+ from decord import VideoReader
10
+ import json
11
+ import os
12
+ from data.utils import pre_caption
13
+
14
+ decord.bridge.set_bridge("torch")
15
+
16
+ class ImageNorm(object):
17
+ """Apply Normalization to Image Pixels on GPU
18
+ """
19
+ def __init__(self, mean, std):
20
+ self.mean = torch.tensor(mean).view(1, 3, 1, 1)
21
+ self.std = torch.tensor(std).view(1, 3, 1, 1)
22
+
23
+ def __call__(self, img):
24
+
25
+ if torch.max(img) > 1 and self.mean.max() <= 1:
26
+ img.div_(255.)
27
+ return img.sub_(self.mean).div_(self.std)
28
+
29
+ def load_jsonl(filename):
30
+ with open(filename, "r") as f:
31
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
32
+
33
+
34
+ class VideoDataset(Dataset):
35
+
36
+ def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
37
+ '''
38
+ image_root (string): Root directory of video
39
+ ann_root (string): directory to store the annotation file
40
+ '''
41
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
42
+ filename = 'msrvtt_test.jsonl'
43
+
44
+ download_url(url,ann_root)
45
+ self.annotation = load_jsonl(os.path.join(ann_root,filename))
46
+
47
+ self.num_frm = num_frm
48
+ self.frm_sampling_strategy = frm_sampling_strategy
49
+ self.max_img_size = max_img_size
50
+ self.video_root = video_root
51
+ self.video_fmt = video_fmt
52
+ self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
53
+
54
+ self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
55
+ self.txt2video = [i for i in range(len(self.annotation))]
56
+ self.video2txt = self.txt2video
57
+
58
+
59
+ def __len__(self):
60
+ return len(self.annotation)
61
+
62
+ def __getitem__(self, index):
63
+
64
+ ann = self.annotation[index]
65
+
66
+ video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
67
+
68
+ vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
69
+
70
+ video = self.img_norm(vid_frm_array.float())
71
+
72
+ return video, ann['clip_name']
73
+
74
+
75
+
76
+ def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
77
+ try:
78
+ if not height or not width:
79
+ vr = VideoReader(video_path)
80
+ else:
81
+ vr = VideoReader(video_path, width=width, height=height)
82
+
83
+ vlen = len(vr)
84
+
85
+ if start_time or end_time:
86
+ assert fps > 0, 'must provide video fps if specifying start and end time.'
87
+
88
+ start_idx = min(int(start_time * fps), vlen)
89
+ end_idx = min(int(end_time * fps), vlen)
90
+ else:
91
+ start_idx, end_idx = 0, vlen
92
+
93
+ if self.frm_sampling_strategy == 'uniform':
94
+ frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
95
+ elif self.frm_sampling_strategy == 'rand':
96
+ frame_indices = sorted(random.sample(range(vlen), self.num_frm))
97
+ elif self.frm_sampling_strategy == 'headtail':
98
+ frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
99
+ frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
100
+ frame_indices = frame_indices_head + frame_indices_tail
101
+ else:
102
+ raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
103
+
104
+ raw_sample_frms = vr.get_batch(frame_indices)
105
+ except Exception as e:
106
+ return None
107
+
108
+ raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
109
+
110
+ return raw_sample_frms
SMILE/BLIP/data/vqa_dataset.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from data.utils import pre_question
9
+
10
+ from torchvision.datasets.utils import download_url
11
+
12
+ class vqa_dataset(Dataset):
13
+ def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
14
+ self.split = split
15
+
16
+ self.transform = transform
17
+ self.vqa_root = vqa_root
18
+ self.vg_root = vg_root
19
+
20
+ if split=='train':
21
+ urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
22
+ 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
23
+ 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
24
+
25
+ self.annotation = []
26
+ for f in train_files:
27
+ download_url(urls[f],ann_root)
28
+ self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
29
+ else:
30
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
31
+ self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
32
+
33
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
34
+ self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
35
+
36
+
37
+ def __len__(self):
38
+ return len(self.annotation)
39
+
40
+ def __getitem__(self, index):
41
+
42
+ ann = self.annotation[index]
43
+
44
+ if ann['dataset']=='vqa':
45
+ image_path = os.path.join(self.vqa_root,ann['image'])
46
+ elif ann['dataset']=='vg':
47
+ image_path = os.path.join(self.vg_root,ann['image'])
48
+
49
+ image = Image.open(image_path).convert('RGB')
50
+ image = self.transform(image)
51
+
52
+ if self.split == 'test':
53
+ question = pre_question(ann['question'])
54
+ question_id = ann['question_id']
55
+ return image, question, question_id
56
+
57
+
58
+ elif self.split=='train':
59
+
60
+ question = pre_question(ann['question'])
61
+
62
+ if ann['dataset']=='vqa':
63
+ answer_weight = {}
64
+ for answer in ann['answer']:
65
+ if answer in answer_weight.keys():
66
+ answer_weight[answer] += 1/len(ann['answer'])
67
+ else:
68
+ answer_weight[answer] = 1/len(ann['answer'])
69
+
70
+ answers = list(answer_weight.keys())
71
+ weights = list(answer_weight.values())
72
+
73
+ elif ann['dataset']=='vg':
74
+ answers = [ann['answer']]
75
+ weights = [0.2]
76
+
77
+ return image, question, answers, weights
78
+
79
+
80
+ def vqa_collate_fn(batch):
81
+ image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
82
+ for image, question, answer, weights in batch:
83
+ image_list.append(image)
84
+ question_list.append(question)
85
+ weight_list += weights
86
+ answer_list += answer
87
+ n.append(len(answer))
88
+ return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
SMILE/BLIP/demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
SMILE/BLIP/models/__init__.py ADDED
File without changes
SMILE/BLIP/models/blip.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ from models.vit import VisionTransformer, interpolate_pos_embed
12
+ from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from transformers import BertTokenizer
14
+
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+
19
+ import os
20
+ from urllib.parse import urlparse
21
+ from timm.models.hub import download_cached_file
22
+
23
+ class BLIP_Base(nn.Module):
24
+ def __init__(self,
25
+ med_config = 'configs/med_config.json',
26
+ image_size = 224,
27
+ vit = 'base',
28
+ vit_grad_ckpt = False,
29
+ vit_ckpt_layer = 0,
30
+ ):
31
+ """
32
+ Args:
33
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
34
+ image_size (int): input image size
35
+ vit (str): model size of vision transformer
36
+ """
37
+ super().__init__()
38
+
39
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
40
+ self.tokenizer = init_tokenizer()
41
+ med_config = BertConfig.from_json_file(med_config)
42
+ med_config.encoder_width = vision_width
43
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
44
+
45
+
46
+ def forward(self, image, caption, mode):
47
+
48
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
49
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
50
+
51
+ if mode=='image':
52
+ # return image features
53
+ image_embeds = self.visual_encoder(image)
54
+ return image_embeds
55
+
56
+ elif mode=='text':
57
+ # return text features
58
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
59
+ return_dict = True, mode = 'text')
60
+ return text_output.last_hidden_state
61
+
62
+ elif mode=='multimodal':
63
+ # return multimodel features
64
+ image_embeds = self.visual_encoder(image)
65
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
66
+
67
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
68
+ output = self.text_encoder(text.input_ids,
69
+ attention_mask = text.attention_mask,
70
+ encoder_hidden_states = image_embeds,
71
+ encoder_attention_mask = image_atts,
72
+ return_dict = True,
73
+ )
74
+ return output.last_hidden_state
75
+
76
+
77
+
78
+ class BLIP_Decoder(nn.Module):
79
+ def __init__(self,
80
+ med_config = 'configs/med_config.json',
81
+ image_size = 384,
82
+ vit = 'base',
83
+ vit_grad_ckpt = False,
84
+ vit_ckpt_layer = 0,
85
+ prompt = 'a picture of ',
86
+ ):
87
+ """
88
+ Args:
89
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
90
+ image_size (int): input image size
91
+ vit (str): model size of vision transformer
92
+ """
93
+ super().__init__()
94
+
95
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
96
+ self.tokenizer = init_tokenizer()
97
+ med_config = BertConfig.from_json_file(med_config)
98
+ med_config.encoder_width = vision_width
99
+ self.text_decoder = BertLMHeadModel(config=med_config)
100
+
101
+ self.prompt = prompt
102
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
103
+
104
+
105
+ def forward(self, image, caption):
106
+
107
+ image_embeds = self.visual_encoder(image)
108
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
109
+
110
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
111
+
112
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
113
+
114
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
115
+ decoder_targets[:,:self.prompt_length] = -100
116
+
117
+ decoder_output = self.text_decoder(text.input_ids,
118
+ attention_mask = text.attention_mask,
119
+ encoder_hidden_states = image_embeds,
120
+ encoder_attention_mask = image_atts,
121
+ labels = decoder_targets,
122
+ return_dict = True,
123
+ )
124
+ loss_lm = decoder_output.loss
125
+
126
+ return loss_lm
127
+
128
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
129
+ image_embeds = self.visual_encoder(image)
130
+
131
+ if not sample:
132
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
133
+
134
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
135
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
136
+
137
+ prompt = [self.prompt] * image.size(0)
138
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
139
+ input_ids[:,0] = self.tokenizer.bos_token_id
140
+ input_ids = input_ids[:, :-1]
141
+
142
+ if sample:
143
+ #nucleus sampling
144
+ outputs = self.text_decoder.generate(input_ids=input_ids,
145
+ max_length=max_length,
146
+ min_length=min_length,
147
+ do_sample=True,
148
+ top_p=top_p,
149
+ num_return_sequences=1,
150
+ eos_token_id=self.tokenizer.sep_token_id,
151
+ pad_token_id=self.tokenizer.pad_token_id,
152
+ repetition_penalty=1.1,
153
+ **model_kwargs)
154
+ else:
155
+ #beam search
156
+ outputs = self.text_decoder.generate(input_ids=input_ids,
157
+ max_length=max_length,
158
+ min_length=min_length,
159
+ num_beams=num_beams,
160
+ eos_token_id=self.tokenizer.sep_token_id,
161
+ pad_token_id=self.tokenizer.pad_token_id,
162
+ repetition_penalty=repetition_penalty,
163
+ **model_kwargs)
164
+
165
+ captions = []
166
+ for output in outputs:
167
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
168
+ captions.append(caption[len(self.prompt):])
169
+ return captions
170
+
171
+
172
+ def blip_decoder(pretrained='',**kwargs):
173
+ model = BLIP_Decoder(**kwargs)
174
+ if pretrained:
175
+ model,msg = load_checkpoint(model,pretrained)
176
+ assert(len(msg.missing_keys)==0)
177
+ return model
178
+
179
+ def blip_feature_extractor(pretrained='',**kwargs):
180
+ model = BLIP_Base(**kwargs)
181
+ if pretrained:
182
+ model,msg = load_checkpoint(model,pretrained)
183
+ assert(len(msg.missing_keys)==0)
184
+ return model
185
+
186
+ def init_tokenizer():
187
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
188
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
189
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
190
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
191
+ return tokenizer
192
+
193
+
194
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
195
+
196
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
197
+ if vit=='base':
198
+ vision_width = 768
199
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
200
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
201
+ drop_path_rate=0 or drop_path_rate
202
+ )
203
+ elif vit=='large':
204
+ vision_width = 1024
205
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
206
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
207
+ drop_path_rate=0.1 or drop_path_rate
208
+ )
209
+ return visual_encoder, vision_width
210
+
211
+ def is_url(url_or_filename):
212
+ parsed = urlparse(url_or_filename)
213
+ return parsed.scheme in ("http", "https")
214
+
215
+ def load_checkpoint(model,url_or_filename):
216
+ if is_url(url_or_filename):
217
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
218
+ checkpoint = torch.load(cached_file, map_location='cpu')
219
+ elif os.path.isfile(url_or_filename):
220
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
221
+ else:
222
+ raise RuntimeError('checkpoint url or path is invalid')
223
+
224
+ state_dict = checkpoint['model']
225
+
226
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
227
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
228
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
229
+ model.visual_encoder_m)
230
+ for key in model.state_dict().keys():
231
+ if key in state_dict.keys():
232
+ if state_dict[key].shape!=model.state_dict()[key].shape:
233
+ del state_dict[key]
234
+
235
+ msg = model.load_state_dict(state_dict,strict=False)
236
+ print('load checkpoint from %s'%url_or_filename)
237
+ return model,msg
238
+
SMILE/BLIP/models/blip_vqa.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.med import BertConfig, BertModel, BertLMHeadModel
2
+ from models.blip import create_vit, init_tokenizer, load_checkpoint
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from transformers import BertTokenizer
8
+ import numpy as np
9
+
10
+ class BLIP_VQA(nn.Module):
11
+ def __init__(self,
12
+ med_config = 'configs/med_config.json',
13
+ image_size = 480,
14
+ vit = 'base',
15
+ vit_grad_ckpt = False,
16
+ vit_ckpt_layer = 0,
17
+ ):
18
+ """
19
+ Args:
20
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
21
+ image_size (int): input image size
22
+ vit (str): model size of vision transformer
23
+ """
24
+ super().__init__()
25
+
26
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
27
+ self.tokenizer = init_tokenizer()
28
+
29
+ encoder_config = BertConfig.from_json_file(med_config)
30
+ encoder_config.encoder_width = vision_width
31
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
32
+
33
+ decoder_config = BertConfig.from_json_file(med_config)
34
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
35
+
36
+
37
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
38
+
39
+ image_embeds = self.visual_encoder(image)
40
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
41
+
42
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
43
+ return_tensors="pt").to(image.device)
44
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
45
+
46
+ if train:
47
+ '''
48
+ n: number of answers for each question
49
+ weights: weight for each answer
50
+ '''
51
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
52
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
53
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
54
+
55
+ question_output = self.text_encoder(question.input_ids,
56
+ attention_mask = question.attention_mask,
57
+ encoder_hidden_states = image_embeds,
58
+ encoder_attention_mask = image_atts,
59
+ return_dict = True)
60
+
61
+ question_states = []
62
+ question_atts = []
63
+ for b, n in enumerate(n):
64
+ question_states += [question_output.last_hidden_state[b]]*n
65
+ question_atts += [question.attention_mask[b]]*n
66
+ question_states = torch.stack(question_states,0)
67
+ question_atts = torch.stack(question_atts,0)
68
+
69
+ answer_output = self.text_decoder(answer.input_ids,
70
+ attention_mask = answer.attention_mask,
71
+ encoder_hidden_states = question_states,
72
+ encoder_attention_mask = question_atts,
73
+ labels = answer_targets,
74
+ return_dict = True,
75
+ reduction = 'none',
76
+ )
77
+
78
+ loss = weights * answer_output.loss
79
+ loss = loss.sum()/image.size(0)
80
+
81
+ return loss
82
+
83
+
84
+ else:
85
+ question_output = self.text_encoder(question.input_ids,
86
+ attention_mask = question.attention_mask,
87
+ encoder_hidden_states = image_embeds,
88
+ encoder_attention_mask = image_atts,
89
+ return_dict = True)
90
+
91
+ if inference=='generate':
92
+ num_beams = 3
93
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
94
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
95
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
96
+
97
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
98
+
99
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
100
+ max_length=10,
101
+ min_length=1,
102
+ num_beams=num_beams,
103
+ eos_token_id=self.tokenizer.sep_token_id,
104
+ pad_token_id=self.tokenizer.pad_token_id,
105
+ **model_kwargs)
106
+
107
+ answers = []
108
+ for output in outputs:
109
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
110
+ answers.append(answer)
111
+ return answers
112
+
113
+ elif inference=='rank':
114
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
115
+ answer.input_ids, answer.attention_mask, k_test)
116
+ return max_ids
117
+
118
+
119
+
120
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
121
+
122
+ num_ques = question_states.size(0)
123
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
124
+
125
+ start_output = self.text_decoder(start_ids,
126
+ encoder_hidden_states = question_states,
127
+ encoder_attention_mask = question_atts,
128
+ return_dict = True,
129
+ reduction = 'none')
130
+ logits = start_output.logits[:,0,:] # first token's logit
131
+
132
+ # topk_probs: top-k probability
133
+ # topk_ids: [num_question, k]
134
+ answer_first_token = answer_ids[:,1]
135
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
136
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
137
+
138
+ # answer input: [num_question*k, answer_len]
139
+ input_ids = []
140
+ input_atts = []
141
+ for b, topk_id in enumerate(topk_ids):
142
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
143
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
144
+ input_ids = torch.cat(input_ids,dim=0)
145
+ input_atts = torch.cat(input_atts,dim=0)
146
+
147
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
148
+
149
+ # repeat encoder's output for top-k answers
150
+ question_states = tile(question_states, 0, k)
151
+ question_atts = tile(question_atts, 0, k)
152
+
153
+ output = self.text_decoder(input_ids,
154
+ attention_mask = input_atts,
155
+ encoder_hidden_states = question_states,
156
+ encoder_attention_mask = question_atts,
157
+ labels = targets_ids,
158
+ return_dict = True,
159
+ reduction = 'none')
160
+
161
+ log_probs_sum = -output.loss
162
+ log_probs_sum = log_probs_sum.view(num_ques,k)
163
+
164
+ max_topk_ids = log_probs_sum.argmax(dim=1)
165
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
166
+
167
+ return max_ids
168
+
169
+
170
+ def blip_vqa(pretrained='',**kwargs):
171
+ model = BLIP_VQA(**kwargs)
172
+ if pretrained:
173
+ model,msg = load_checkpoint(model,pretrained)
174
+ # assert(len(msg.missing_keys)==0)
175
+ return model
176
+
177
+
178
+ def tile(x, dim, n_tile):
179
+ init_dim = x.size(dim)
180
+ repeat_idx = [1] * x.dim()
181
+ repeat_idx[dim] = n_tile
182
+ x = x.repeat(*(repeat_idx))
183
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
184
+ return torch.index_select(x, dim, order_index.to(x.device))
185
+
186
+
SMILE/BLIP/models/med.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
SMILE/BLIP/models/model.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+
4
+ from models.vit import VisionTransformer, interpolate_pos_embed
5
+ from models.med import BertConfig, BertLMHeadModel
6
+ from transformers import BertTokenizer
7
+
8
+
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+
13
+ import os
14
+ from urllib.parse import urlparse
15
+ from timm.models.hub import download_cached_file
16
+
17
+ import pdb
18
+
19
+ class CapModel(nn.Module):
20
+ def __init__(self,
21
+ med_config = 'SMILE/BLIP/configs/med_config.json',
22
+ image_size = 224,
23
+ vit = 'base',
24
+ vit_grad_ckpt = False,
25
+ vit_ckpt_layer = 0,
26
+ prompt = 'a picture of ',
27
+ ):
28
+
29
+ super().__init__()
30
+
31
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
32
+ self.tokenizer = init_tokenizer()
33
+ med_config = BertConfig.from_json_file(med_config)
34
+ med_config.encoder_width = vision_width
35
+ self.text_decoder = BertLMHeadModel(config=med_config)
36
+
37
+ self.prompt = prompt
38
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
39
+
40
+ self.vocab_emb = None
41
+
42
+ def forward(self, image, caption):
43
+
44
+ image_embeds = self.visual_encoder(image)
45
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
46
+
47
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
48
+
49
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
50
+
51
+ # # First-token Shifting: Change the first token 'word' to '##word'
52
+ # for i in range(text.input_ids.size(0)):
53
+ # text.input_ids[i, self.prompt_length] = self.tokenizer.convert_tokens_to_ids('##' + self.tokenizer.convert_ids_to_tokens(text.input_ids[i,self.prompt_length].item()))
54
+
55
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
56
+ decoder_targets[:,:self.prompt_length] = -100
57
+
58
+ decoder_output = self.text_decoder(text.input_ids,
59
+ attention_mask = text.attention_mask,
60
+ encoder_hidden_states = image_embeds,
61
+ encoder_attention_mask = image_atts,
62
+ labels = decoder_targets,
63
+ return_dict = True,
64
+ )
65
+
66
+ # # mle
67
+ # mle_loss = decoder_output.loss
68
+
69
+ label = text.input_ids[:, self.prompt_length:].contiguous()
70
+ bs = text.input_ids.size(0)
71
+ N = label.size(1)
72
+ vs = self.text_decoder.config.vocab_size
73
+ logits = decoder_output.logits[:, self.prompt_length-1:-1]
74
+
75
+ # smile
76
+ mask = torch.zeros(bs, vs).to(logits.device).scatter_(1, label, True)
77
+ mask[:, 0] = 0
78
+ mask = mask.unsqueeze(1).expand(-1, N, -1).clone()
79
+ mask[:, 0, :] = 1 # mle on first token
80
+ selected_logits = logits.masked_fill(mask == 0, -1e9)
81
+ smile_loss = F.cross_entropy(selected_logits.view(-1, vs), label.view(-1), ignore_index=0, reduction='mean')
82
+
83
+ # # reverse smile
84
+ # reverse_mask = torch.ones(bs, vs).to(logits.device).scatter_(1, label, False)
85
+ # reverse_mask = reverse_mask.unsqueeze(1).expand(-1, N, -1).clone()
86
+ # reverse_mask.scatter_(2, label.unsqueeze(-1), 1)
87
+ # reverse_mask[:, 0, :] = 1 # mle on first token
88
+ # reverse_selected_logits = logits.masked_fill(reverse_mask == 0, -1e9)
89
+ # reverse_smile_loss = F.cross_entropy(reverse_selected_logits.view(-1, vs), label.view(-1), ignore_index=0, reduction='mean')
90
+
91
+ # # random sample (efficient implementation)
92
+ # sample_num = 10
93
+ # rand_indices = torch.randint(vs, (bs, N, sample_num)).to(label.device)
94
+ # rand_indices_with_label = torch.cat((rand_indices, label.unsqueeze(2)), dim=2) # (bs, N, sample_num + 1)
95
+ # batch_indices = torch.arange(bs)[:, None, None].expand(bs, N, sample_num + 1)
96
+ # seq_indices = torch.arange(N)[None, :, None].expand(bs, N, sample_num + 1)
97
+ # random_mask = torch.zeros(bs, N, vs).to(label.device)
98
+ # random_mask[batch_indices, seq_indices, rand_indices_with_label] = 1
99
+ # random_mask[:, :, 0] = 0
100
+ # random_selected_logits = logits.masked_fill(mask == 0, -1e9)
101
+ # random_smile_loss = F.cross_entropy(random_selected_logits.view(-1, vs), label.view(-1), ignore_index=0, reduction='mean')
102
+
103
+ loss = smile_loss
104
+ # loss = 0.5 * reverse_smile_loss + 0.5 * mle_loss
105
+
106
+ return loss
107
+
108
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
109
+ image_embeds = self.visual_encoder(image)
110
+
111
+ if not sample:
112
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
113
+
114
+ prompt = [self.prompt] * image.size(0)
115
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
116
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
117
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
118
+
119
+ input_ids[:,0] = self.tokenizer.bos_token_id
120
+ input_ids = input_ids[:, :-1]
121
+
122
+ if sample:
123
+ #nucleus sampling
124
+ outputs = self.text_decoder.generate(input_ids=input_ids,
125
+ max_length=max_length,
126
+ min_length=min_length,
127
+ do_sample=True,
128
+ top_p=top_p,
129
+ num_return_sequences=1,
130
+ eos_token_id=self.tokenizer.sep_token_id,
131
+ pad_token_id=self.tokenizer.pad_token_id,
132
+ repetition_penalty=1.1,
133
+ **model_kwargs)
134
+ else:
135
+ #beam search
136
+ outputs = self.text_decoder.generate(input_ids=input_ids,
137
+ max_length=max_length,
138
+ min_length=min_length,
139
+ num_beams=num_beams,
140
+ eos_token_id=self.tokenizer.sep_token_id,
141
+ pad_token_id=self.tokenizer.pad_token_id,
142
+ repetition_penalty=repetition_penalty,
143
+ **model_kwargs)
144
+
145
+ captions = []
146
+ for output in outputs:
147
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
148
+ captions.append(caption[len(self.prompt):])
149
+ # caption = self.tokenizer.decode(output[4:], skip_special_tokens=True)
150
+ # captions.append(caption)
151
+ return captions
152
+
153
+
154
+ def caption_model(pretrained='',**kwargs):
155
+ model = CapModel(**kwargs)
156
+ if pretrained:
157
+ model,msg = load_checkpoint(model,pretrained)
158
+ return model
159
+
160
+ def init_tokenizer():
161
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
162
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
163
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
164
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
165
+ return tokenizer
166
+
167
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
168
+
169
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
170
+ if vit=='base':
171
+ vision_width = 768
172
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
173
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
174
+ drop_path_rate=0 or drop_path_rate
175
+ )
176
+ elif vit=='large':
177
+ vision_width = 1024
178
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
179
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
180
+ drop_path_rate=0.1 or drop_path_rate
181
+ )
182
+ return visual_encoder, vision_width
183
+
184
+ def is_url(url_or_filename):
185
+ parsed = urlparse(url_or_filename)
186
+ return parsed.scheme in ("http", "https")
187
+
188
+ def load_checkpoint(model,url_or_filename):
189
+ if is_url(url_or_filename):
190
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
191
+ checkpoint = torch.load(cached_file, map_location='cpu')
192
+ elif os.path.isfile(url_or_filename):
193
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
194
+ else:
195
+ raise RuntimeError('checkpoint url or path is invalid')
196
+
197
+ state_dict = checkpoint['model']
198
+
199
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
200
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
201
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
202
+ model.visual_encoder_m)
203
+
204
+ for key in model.state_dict().keys():
205
+ if key in state_dict.keys():
206
+ if state_dict[key].shape!=model.state_dict()[key].shape:
207
+ del state_dict[key]
208
+
209
+ msg = model.load_state_dict(state_dict, strict=False)
210
+ print('load checkpoint from %s'%url_or_filename)
211
+ return model,msg
SMILE/BLIP/models/vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
SMILE/BLIP/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
SMILE/BLIP/scripts/eval.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
2
+ torchrun \
3
+ --nproc_per_node=4 \
4
+ --master_port 30010 \
5
+ train_caption.py \
6
+ --evaluate \
7
+ --eval_split test \
8
+ --config configs/caption_coco.yaml \
9
+ --output_dir output/blip
SMILE/BLIP/scripts/train.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=1,2,3,4 \
2
+ torchrun \
3
+ --nproc_per_node=4 \
4
+ --master_port 30000 \
5
+ train_caption.py \
6
+ --config configs/caption_coco.yaml \
7
+ --output_dir output/blip
SMILE/BLIP/train_caption.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+
8
+ * Modified by Zihao Yue
9
+ '''
10
+
11
+ import argparse
12
+ import os
13
+ try:
14
+ import ruamel_yaml as yaml
15
+ except:
16
+ import ruamel.yaml as yaml
17
+ import numpy as np
18
+ import random
19
+ import time
20
+ import datetime
21
+ import json
22
+ from pathlib import Path
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torch.backends.cudnn as cudnn
28
+ import torch.distributed as dist
29
+ from torch.utils.data import DataLoader
30
+
31
+ from models.model import caption_model
32
+ import utils
33
+ from utils import warmup_lr_schedule, step_lr_schedule, cosine_lr_schedule
34
+ from data import create_dataset, create_sampler, create_loader
35
+ from data.utils import save_result, coco_caption_eval
36
+
37
+ def train(model, data_loader, optimizer, epoch, device):
38
+ # train
39
+ model.train()
40
+
41
+ metric_logger = utils.MetricLogger(delimiter=" ")
42
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
43
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
44
+ header = 'Train Caption Epoch: [{}]'.format(epoch)
45
+ print_freq = 50
46
+
47
+ for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
48
+ image = image.to(device)
49
+
50
+ loss = model(image, caption)
51
+
52
+ optimizer.zero_grad()
53
+ loss.backward()
54
+ optimizer.step()
55
+
56
+ metric_logger.update(loss=loss.item())
57
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
58
+
59
+ # gather the stats from all processes
60
+ metric_logger.synchronize_between_processes()
61
+ print("Averaged stats:", metric_logger.global_avg())
62
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
63
+
64
+
65
+ @torch.no_grad()
66
+ def evaluate(model, data_loader, device, config):
67
+ # evaluate
68
+ model.eval()
69
+
70
+ metric_logger = utils.MetricLogger(delimiter=" ")
71
+ header = 'Caption generation:'
72
+ print_freq = 10
73
+
74
+ result = []
75
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
76
+
77
+ image = image.to(device)
78
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], min_length=config['min_length'])
79
+
80
+ for caption, img_id in zip(captions, image_id):
81
+ result.append({"image_id": img_id.item(), "caption": caption})
82
+
83
+ return result
84
+
85
+
86
+ def main(args, config):
87
+ utils.init_distributed_mode(args)
88
+
89
+ device = torch.device(args.device)
90
+
91
+ # fix the seed for reproducibility
92
+ seed = args.seed + utils.get_rank()
93
+ torch.manual_seed(seed)
94
+ np.random.seed(seed)
95
+ random.seed(seed)
96
+ cudnn.benchmark = True
97
+
98
+ #### Dataset ####
99
+ print("Creating captioning dataset")
100
+ train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
101
+
102
+ if args.distributed:
103
+ num_tasks = utils.get_world_size()
104
+ global_rank = utils.get_rank()
105
+ samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
106
+ else:
107
+ samplers = [None, None, None]
108
+
109
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
110
+ batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
111
+ is_trains=[True, False, False], collate_fns=[None,None,None])
112
+
113
+ #### Model ####
114
+ print("Creating model")
115
+ model = caption_model(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
116
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
117
+ prompt=config['prompt'])
118
+
119
+ model = model.to(device)
120
+
121
+ model_without_ddp = model
122
+ if args.distributed:
123
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
124
+ model_without_ddp = model.module
125
+
126
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
127
+
128
+ best = 0
129
+ best_epoch = 0
130
+
131
+ print("Start training")
132
+ start_time = time.time()
133
+ for epoch in range(0, config['max_epoch']):
134
+ if not args.evaluate:
135
+ if args.distributed:
136
+ train_loader.sampler.set_epoch(epoch)
137
+
138
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
139
+
140
+ train_stats = train(model, train_loader, optimizer, epoch, device)
141
+
142
+ if args.eval_split == 'val' or not args.evaluate:
143
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
144
+ val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
145
+ else:
146
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
147
+ test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
148
+
149
+ if utils.is_main_process():
150
+
151
+ if args.eval_split == 'val' or not args.evaluate:
152
+ coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
153
+ else:
154
+ coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
155
+
156
+ if args.evaluate:
157
+ if args.eval_split == 'val':
158
+ log_stats = {
159
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
160
+ }
161
+ else:
162
+ log_stats = {
163
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
164
+ }
165
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
166
+ f.write(json.dumps(log_stats) + "\n")
167
+ else:
168
+ save_obj = {
169
+ 'model': model_without_ddp.state_dict(),
170
+ 'optimizer': optimizer.state_dict(),
171
+ 'config': config,
172
+ 'epoch': epoch,
173
+ }
174
+
175
+ if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
176
+ best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
177
+ best_epoch = epoch
178
+ # torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
179
+ # save each epoch
180
+ torch.save(save_obj, os.path.join(args.output_dir, 'epoch%d.pth'%epoch))
181
+
182
+ log_stats = {**{f'train_{k}': float(v) for k, v in train_stats.items()},
183
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
184
+ 'epoch': epoch,
185
+ 'best_epoch': best_epoch,
186
+ }
187
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
188
+ f.write(json.dumps(log_stats) + "\n")
189
+
190
+ if args.evaluate:
191
+ break
192
+ dist.barrier()
193
+
194
+ total_time = time.time() - start_time
195
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
196
+ print('Training time {}'.format(total_time_str))
197
+
198
+
199
+ if __name__ == '__main__':
200
+ parser = argparse.ArgumentParser()
201
+ parser.add_argument('--config', default='./configs/caption_coco.yaml')
202
+ parser.add_argument('--output_dir', default='output/caption_coco')
203
+ parser.add_argument('--evaluate', action='store_true')
204
+ parser.add_argument('--device', default='cuda')
205
+ parser.add_argument('--seed', default=42, type=int)
206
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
207
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
208
+ parser.add_argument('--distributed', default=True, type=bool)
209
+ parser.add_argument('--eval_split', default='val', type=str)
210
+ args = parser.parse_args()
211
+
212
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
213
+
214
+ args.result_dir = os.path.join(args.output_dir, 'result')
215
+
216
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
217
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
218
+
219
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
220
+
221
+ main(args, config)
SMILE/BLIP/transform/randaugment.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ ## aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ '''
12
+ same output as PIL.ImageOps.autocontrast
13
+ '''
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ '''
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ '''
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0: return ch
55
+ n = np.empty_like(hist)
56
+ n[0] = step // 2
57
+ n[1:] = hist[:-1]
58
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59
+ return table[ch]
60
+
61
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
62
+ out = cv2.merge(channels)
63
+ return out
64
+
65
+
66
+ def rotate_func(img, degree, fill=(0, 0, 0)):
67
+ '''
68
+ like PIL, rotate by degree, not radians
69
+ '''
70
+ H, W = img.shape[0], img.shape[1]
71
+ center = W / 2, H / 2
72
+ M = cv2.getRotationMatrix2D(center, degree, 1)
73
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74
+ return out
75
+
76
+
77
+ def solarize_func(img, thresh=128):
78
+ '''
79
+ same output as PIL.ImageOps.posterize
80
+ '''
81
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
82
+ table = table.clip(0, 255).astype(np.uint8)
83
+ out = table[img]
84
+ return out
85
+
86
+
87
+ def color_func(img, factor):
88
+ '''
89
+ same output as PIL.ImageEnhance.Color
90
+ '''
91
+ ## implementation according to PIL definition, quite slow
92
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93
+ # out = blend(degenerate, img, factor)
94
+ # M = (
95
+ # np.eye(3) * factor
96
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97
+ # )[np.newaxis, np.newaxis, :]
98
+ M = (
99
+ np.float32([
100
+ [0.886, -0.114, -0.114],
101
+ [-0.587, 0.413, -0.587],
102
+ [-0.299, -0.299, 0.701]]) * factor
103
+ + np.float32([[0.114], [0.587], [0.299]])
104
+ )
105
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106
+ return out
107
+
108
+
109
+ def contrast_func(img, factor):
110
+ """
111
+ same output as PIL.ImageEnhance.Contrast
112
+ """
113
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114
+ table = np.array([(
115
+ el - mean) * factor + mean
116
+ for el in range(256)
117
+ ]).clip(0, 255).astype(np.uint8)
118
+ out = table[img]
119
+ return out
120
+
121
+
122
+ def brightness_func(img, factor):
123
+ '''
124
+ same output as PIL.ImageEnhance.Contrast
125
+ '''
126
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127
+ out = table[img]
128
+ return out
129
+
130
+
131
+ def sharpness_func(img, factor):
132
+ '''
133
+ The differences the this result and PIL are all on the 4 boundaries, the center
134
+ areas are same
135
+ '''
136
+ kernel = np.ones((3, 3), dtype=np.float32)
137
+ kernel[1][1] = 5
138
+ kernel /= 13
139
+ degenerate = cv2.filter2D(img, -1, kernel)
140
+ if factor == 0.0:
141
+ out = degenerate
142
+ elif factor == 1.0:
143
+ out = img
144
+ else:
145
+ out = img.astype(np.float32)
146
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148
+ out = out.astype(np.uint8)
149
+ return out
150
+
151
+
152
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
153
+ H, W = img.shape[0], img.shape[1]
154
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
155
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
160
+ '''
161
+ same output as PIL.Image.transform
162
+ '''
163
+ H, W = img.shape[0], img.shape[1]
164
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
165
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166
+ return out
167
+
168
+
169
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
170
+ '''
171
+ same output as PIL.Image.transform
172
+ '''
173
+ H, W = img.shape[0], img.shape[1]
174
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
175
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176
+ return out
177
+
178
+
179
+ def posterize_func(img, bits):
180
+ '''
181
+ same output as PIL.ImageOps.posterize
182
+ '''
183
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184
+ return out
185
+
186
+
187
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
188
+ H, W = img.shape[0], img.shape[1]
189
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
190
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191
+ return out
192
+
193
+
194
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
195
+ replace = np.array(replace, dtype=np.uint8)
196
+ H, W = img.shape[0], img.shape[1]
197
+ rh, rw = np.random.random(2)
198
+ pad_size = pad_size // 2
199
+ ch, cw = int(rh * H), int(rw * W)
200
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202
+ out = img.copy()
203
+ out[x1:x2, y1:y2, :] = replace
204
+ return out
205
+
206
+
207
+ ### level to args
208
+ def enhance_level_to_args(MAX_LEVEL):
209
+ def level_to_args(level):
210
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211
+ return level_to_args
212
+
213
+
214
+ def shear_level_to_args(MAX_LEVEL, replace_value):
215
+ def level_to_args(level):
216
+ level = (level / MAX_LEVEL) * 0.3
217
+ if np.random.random() > 0.5: level = -level
218
+ return (level, replace_value)
219
+
220
+ return level_to_args
221
+
222
+
223
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224
+ def level_to_args(level):
225
+ level = (level / MAX_LEVEL) * float(translate_const)
226
+ if np.random.random() > 0.5: level = -level
227
+ return (level, replace_value)
228
+
229
+ return level_to_args
230
+
231
+
232
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233
+ def level_to_args(level):
234
+ level = int((level / MAX_LEVEL) * cutout_const)
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def solarize_level_to_args(MAX_LEVEL):
241
+ def level_to_args(level):
242
+ level = int((level / MAX_LEVEL) * 256)
243
+ return (level, )
244
+ return level_to_args
245
+
246
+
247
+ def none_level_to_args(level):
248
+ return ()
249
+
250
+
251
+ def posterize_level_to_args(MAX_LEVEL):
252
+ def level_to_args(level):
253
+ level = int((level / MAX_LEVEL) * 4)
254
+ return (level, )
255
+ return level_to_args
256
+
257
+
258
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
259
+ def level_to_args(level):
260
+ level = (level / MAX_LEVEL) * 30
261
+ if np.random.random() < 0.5:
262
+ level = -level
263
+ return (level, replace_value)
264
+
265
+ return level_to_args
266
+
267
+
268
+ func_dict = {
269
+ 'Identity': identity_func,
270
+ 'AutoContrast': autocontrast_func,
271
+ 'Equalize': equalize_func,
272
+ 'Rotate': rotate_func,
273
+ 'Solarize': solarize_func,
274
+ 'Color': color_func,
275
+ 'Contrast': contrast_func,
276
+ 'Brightness': brightness_func,
277
+ 'Sharpness': sharpness_func,
278
+ 'ShearX': shear_x_func,
279
+ 'TranslateX': translate_x_func,
280
+ 'TranslateY': translate_y_func,
281
+ 'Posterize': posterize_func,
282
+ 'ShearY': shear_y_func,
283
+ }
284
+
285
+ translate_const = 10
286
+ MAX_LEVEL = 10
287
+ replace_value = (128, 128, 128)
288
+ arg_dict = {
289
+ 'Identity': none_level_to_args,
290
+ 'AutoContrast': none_level_to_args,
291
+ 'Equalize': none_level_to_args,
292
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
294
+ 'Color': enhance_level_to_args(MAX_LEVEL),
295
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
296
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
297
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299
+ 'TranslateX': translate_level_to_args(
300
+ translate_const, MAX_LEVEL, replace_value
301
+ ),
302
+ 'TranslateY': translate_level_to_args(
303
+ translate_const, MAX_LEVEL, replace_value
304
+ ),
305
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
306
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307
+ }
308
+
309
+
310
+ class RandomAugment(object):
311
+
312
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313
+ self.N = N
314
+ self.M = M
315
+ self.isPIL = isPIL
316
+ if augs:
317
+ self.augs = augs
318
+ else:
319
+ self.augs = list(arg_dict.keys())
320
+
321
+ def get_random_ops(self):
322
+ sampled_ops = np.random.choice(self.augs, self.N)
323
+ return [(op, 0.5, self.M) for op in sampled_ops]
324
+
325
+ def __call__(self, img):
326
+ if self.isPIL:
327
+ img = np.array(img)
328
+ ops = self.get_random_ops()
329
+ for name, prob, level in ops:
330
+ if np.random.random() > prob:
331
+ continue
332
+ args = arg_dict[name](level)
333
+ img = func_dict[name](img, *args)
334
+ return img
335
+
336
+
337
+ if __name__ == '__main__':
338
+ a = RandomAugment()
339
+ img = np.random.randn(32, 32, 3)
340
+ a(img)
SMILE/BLIP/utils.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
3
+ """Decay the learning rate"""
4
+ lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
5
+ for param_group in optimizer.param_groups:
6
+ param_group['lr'] = lr
7
+
8
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
9
+ """Warmup the learning rate"""
10
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
11
+ for param_group in optimizer.param_groups:
12
+ param_group['lr'] = lr
13
+
14
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
15
+ """Decay the learning rate"""
16
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
17
+ for param_group in optimizer.param_groups:
18
+ param_group['lr'] = lr
19
+
20
+ import numpy as np
21
+ import io
22
+ import os
23
+ import time
24
+ from collections import defaultdict, deque
25
+ import datetime
26
+
27
+ import torch
28
+ import torch.distributed as dist
29
+
30
+ class SmoothedValue(object):
31
+ """Track a series of values and provide access to smoothed values over a
32
+ window or the global series average.
33
+ """
34
+
35
+ def __init__(self, window_size=20, fmt=None):
36
+ if fmt is None:
37
+ fmt = "{median:.4f} ({global_avg:.4f})"
38
+ self.deque = deque(maxlen=window_size)
39
+ self.total = 0.0
40
+ self.count = 0
41
+ self.fmt = fmt
42
+
43
+ def update(self, value, n=1):
44
+ self.deque.append(value)
45
+ self.count += n
46
+ self.total += value * n
47
+
48
+ def synchronize_between_processes(self):
49
+ """
50
+ Warning: does not synchronize the deque!
51
+ """
52
+ if not is_dist_avail_and_initialized():
53
+ return
54
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55
+ dist.barrier()
56
+ dist.all_reduce(t)
57
+ t = t.tolist()
58
+ self.count = int(t[0])
59
+ self.total = t[1]
60
+
61
+ @property
62
+ def median(self):
63
+ d = torch.tensor(list(self.deque))
64
+ return d.median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
69
+ return d.mean().item()
70
+
71
+ @property
72
+ def global_avg(self):
73
+ return self.total / self.count
74
+
75
+ @property
76
+ def max(self):
77
+ return max(self.deque)
78
+
79
+ @property
80
+ def value(self):
81
+ return self.deque[-1]
82
+
83
+ def __str__(self):
84
+ return self.fmt.format(
85
+ median=self.median,
86
+ avg=self.avg,
87
+ global_avg=self.global_avg,
88
+ max=self.max,
89
+ value=self.value)
90
+
91
+
92
+ class MetricLogger(object):
93
+ def __init__(self, delimiter="\t"):
94
+ self.meters = defaultdict(SmoothedValue)
95
+ self.delimiter = delimiter
96
+
97
+ def update(self, **kwargs):
98
+ for k, v in kwargs.items():
99
+ if isinstance(v, torch.Tensor):
100
+ v = v.item()
101
+ assert isinstance(v, (float, int))
102
+ self.meters[k].update(v)
103
+
104
+ def __getattr__(self, attr):
105
+ if attr in self.meters:
106
+ return self.meters[attr]
107
+ if attr in self.__dict__:
108
+ return self.__dict__[attr]
109
+ raise AttributeError("'{}' object has no attribute '{}'".format(
110
+ type(self).__name__, attr))
111
+
112
+ def __str__(self):
113
+ loss_str = []
114
+ for name, meter in self.meters.items():
115
+ loss_str.append(
116
+ "{}: {}".format(name, str(meter))
117
+ )
118
+ return self.delimiter.join(loss_str)
119
+
120
+ def global_avg(self):
121
+ loss_str = []
122
+ for name, meter in self.meters.items():
123
+ loss_str.append(
124
+ "{}: {:.4f}".format(name, meter.global_avg)
125
+ )
126
+ return self.delimiter.join(loss_str)
127
+
128
+ def synchronize_between_processes(self):
129
+ for meter in self.meters.values():
130
+ meter.synchronize_between_processes()
131
+
132
+ def add_meter(self, name, meter):
133
+ self.meters[name] = meter
134
+
135
+ def log_every(self, iterable, print_freq, header=None):
136
+ i = 0
137
+ if not header:
138
+ header = ''
139
+ start_time = time.time()
140
+ end = time.time()
141
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
142
+ data_time = SmoothedValue(fmt='{avg:.4f}')
143
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
144
+ log_msg = [
145
+ header,
146
+ '[{0' + space_fmt + '}/{1}]',
147
+ 'eta: {eta}',
148
+ '{meters}',
149
+ 'time: {time}',
150
+ 'data: {data}'
151
+ ]
152
+ if torch.cuda.is_available():
153
+ log_msg.append('max mem: {memory:.0f}')
154
+ log_msg = self.delimiter.join(log_msg)
155
+ MB = 1024.0 * 1024.0
156
+ for obj in iterable:
157
+ data_time.update(time.time() - end)
158
+ yield obj
159
+ iter_time.update(time.time() - end)
160
+ if i % print_freq == 0 or i == len(iterable) - 1:
161
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
162
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
163
+ if torch.cuda.is_available():
164
+ print(log_msg.format(
165
+ i, len(iterable), eta=eta_string,
166
+ meters=str(self),
167
+ time=str(iter_time), data=str(data_time),
168
+ memory=torch.cuda.max_memory_allocated() / MB))
169
+ else:
170
+ print(log_msg.format(
171
+ i, len(iterable), eta=eta_string,
172
+ meters=str(self),
173
+ time=str(iter_time), data=str(data_time)))
174
+ i += 1
175
+ end = time.time()
176
+ total_time = time.time() - start_time
177
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178
+ print('{} Total time: {} ({:.4f} s / it)'.format(
179
+ header, total_time_str, total_time / len(iterable)))
180
+
181
+
182
+ class AttrDict(dict):
183
+ def __init__(self, *args, **kwargs):
184
+ super(AttrDict, self).__init__(*args, **kwargs)
185
+ self.__dict__ = self
186
+
187
+
188
+ def compute_acc(logits, label, reduction='mean'):
189
+ ret = (torch.argmax(logits, dim=1) == label).float()
190
+ if reduction == 'none':
191
+ return ret.detach()
192
+ elif reduction == 'mean':
193
+ return ret.mean().item()
194
+
195
+ def compute_n_params(model, return_str=True):
196
+ tot = 0
197
+ for p in model.parameters():
198
+ w = 1
199
+ for x in p.shape:
200
+ w *= x
201
+ tot += w
202
+ if return_str:
203
+ if tot >= 1e6:
204
+ return '{:.1f}M'.format(tot / 1e6)
205
+ else:
206
+ return '{:.1f}K'.format(tot / 1e3)
207
+ else:
208
+ return tot
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in master process
213
+ """
214
+ import builtins as __builtin__
215
+ builtin_print = __builtin__.print
216
+
217
+ def print(*args, **kwargs):
218
+ force = kwargs.pop('force', False)
219
+ if is_master or force:
220
+ builtin_print(*args, **kwargs)
221
+
222
+ __builtin__.print = print
223
+
224
+
225
+ def is_dist_avail_and_initialized():
226
+ if not dist.is_available():
227
+ return False
228
+ if not dist.is_initialized():
229
+ return False
230
+ return True
231
+
232
+
233
+ def get_world_size():
234
+ if not is_dist_avail_and_initialized():
235
+ return 1
236
+ return dist.get_world_size()
237
+
238
+
239
+ def get_rank():
240
+ if not is_dist_avail_and_initialized():
241
+ return 0
242
+ return dist.get_rank()
243
+
244
+
245
+ def is_main_process():
246
+ return get_rank() == 0
247
+
248
+
249
+ def save_on_master(*args, **kwargs):
250
+ if is_main_process():
251
+ torch.save(*args, **kwargs)
252
+
253
+
254
+ def init_distributed_mode(args):
255
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
256
+ args.rank = int(os.environ["RANK"])
257
+ args.world_size = int(os.environ['WORLD_SIZE'])
258
+ args.gpu = int(os.environ['LOCAL_RANK'])
259
+ elif 'SLURM_PROCID' in os.environ:
260
+ args.rank = int(os.environ['SLURM_PROCID'])
261
+ args.gpu = args.rank % torch.cuda.device_count()
262
+ else:
263
+ print('Not using distributed mode')
264
+ args.distributed = False
265
+ return
266
+
267
+ args.distributed = True
268
+
269
+ torch.cuda.set_device(args.gpu)
270
+ args.dist_backend = 'nccl'
271
+ print('| distributed init (rank {}, word {}): {}'.format(
272
+ args.rank, args.world_size, args.dist_url), flush=True)
273
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274
+ world_size=args.world_size, rank=args.rank)
275
+ torch.distributed.barrier()
276
+ setup_for_distributed(args.rank == 0)
277
+
278
+
SMILE/LICENSE ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ------------
2
+
3
+ BSD 3-Clause "New" or "Revised" License for SMILE/BLIP
4
+
5
+ Copyright (c) 2022, Salesforce.com, Inc.
6
+ All rights reserved.
7
+
8
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
15
+
16
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
17
+
18
+ ------------
19
+
20
+ MIT License for Remaining Contents in SMILE
21
+
22
+ Copyright (c) 2023 Zihao Yue, Renmin University of China
23
+
24
+ Permission is hereby granted, free of charge, to any person obtaining a copy
25
+ of this software and associated documentation files (the "Software"), to deal
26
+ in the Software without restriction, including without limitation the rights
27
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
28
+ copies of the Software, and to permit persons to whom the Software is
29
+ furnished to do so, subject to the following conditions:
30
+
31
+ The above copyright notice and this permission notice shall be included in all
32
+ copies or substantial portions of the Software.
33
+
34
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
35
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
36
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
37
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
38
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
39
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40
+ SOFTWARE.
SMILE/README.md ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div>
2
+ <h2 align="center">
3
+ 🫠 SMILE
4
+ </h2>
5
+ </div>
6
+
7
+ <p align="center">
8
+ <a >
9
+ <img alt="Issues" src="https://img.shields.io/github/issues/yuezih/SMILE?color=blueviolet" />
10
+ </a>
11
+ <a >
12
+ <img alt="Forks" src="https://img.shields.io/github/forks/yuezih/SMILE?color=orange" />
13
+ </a>
14
+ <a >
15
+ <img alt="Stars" src="https://img.shields.io/github/stars/yuezih/SMILE?color=ff69b4" />
16
+ </a>
17
+ <br />
18
+ </p>
19
+
20
+ [Learning Descriptive Image Captioning via Semipermeable Maximum Likelihood Estimation](https://arxiv.org/abs/2306.13460)
21
+
22
+ ![case.png](./assets/case.png)
23
+
24
+ ---
25
+
26
+ ## News 📢
27
+
28
+ - [2023.09.30] We now provide the code and our trained checkpoints (of BLIP) for quick deploying and easy reproduction. The previous demonstrative codes are now available at [demonstrative.md](./assets/demonstrative.md).
29
+ - [2023.06.26] We provide the demonstrative codes to show how to implement SMILE in your codebase, including a pseudocode, a [BLIP](https://github.com/salesforce/BLIP) version, and a [transformers](https://github.com/huggingface/transformers) version.
30
+
31
+ ## Demo
32
+
33
+ We are building online demos. Please stay tuned.
34
+
35
+ ## Usage
36
+
37
+ ```
38
+ git clone https://github.com/yuezih/SMILE
39
+ cd SMILE/BLIP
40
+ ```
41
+
42
+ ### Installation
43
+
44
+ ```
45
+ pip install -r requirements.txt
46
+ ```
47
+
48
+ The code has been tested on PyTorch 2.0.0.
49
+
50
+ ### Data Preparation
51
+
52
+ The data configs are in `SMILE/BLIP/configs/caption_coco.yaml`.
53
+ - Set the `image_root` to your MSCOCO image root.
54
+ - MSCOCO annotation files will be automatically downloaded.
55
+
56
+ ### Checkpoints
57
+
58
+ The pre-trained and MLE-finetuned checkpoints are available at the [original BLIP repo](https://github.com/salesforce/BLIP).
59
+
60
+ We provide our two checkpoints finetuned on MSCOCO with SMILE:
61
+ - `blip_smile_base.pth`: The vanilla SMILE-optimized BLIP.
62
+ - `blip_mle_smile_base.pth`: BLIP finetuned with MLE+SMILE (0.01:0.99), with a compromise between descriptiveness and accuracy.
63
+
64
+ Method|Download|Cap. Len.|Lex. Div.|R@1|R@5|CLIPScore|PPL
65
+ -|:-:|:-:|:-:|:-:|:-:|:-:|:-:
66
+ `blip_smile_base.pth`|[OneDrive](https://1drv.ms/u/s!AocXJ7uKxt6XcsGzBZ4XKoZWKJY?e=BW7fJK)|22.3|4.5|10.0|24.5|75.0|95.6
67
+ `blip_mle_smile_base.pth`|[OneDrive](https://1drv.ms/u/s!AocXJ7uKxt6Xc85rDJCdunDI0jU?e=eDpAGG)|19.8|3.6|**10.9**|**25.1**|76.2|79.4
68
+
69
+ Set the checkpoint path in `SMILE/BLIP/configs/caption_coco.yaml`.
70
+
71
+ ### Training & Inference
72
+
73
+ ```
74
+ bash scripts/train.sh
75
+ ```
76
+
77
+ ```
78
+ bash scripts/eval.sh
79
+ ```
80
+
81
+ Kind reminders:
82
+ - Please use `transformers==4.15.0` rather than a higher version.
83
+ - For `torch<=2.0.0`, replace `torchrun` with `python -m torch.distributed.run` in the training and inference scripts.
84
+
85
+ ## Citation
86
+
87
+ If you find this repo to be helpful for your research, please consider citing our paper:
88
+
89
+ ```bibtex
90
+ @misc{yue2023learning,
91
+ title={Learning Descriptive Image Captioning via Semipermeable Maximum Likelihood Estimation},
92
+ author={Zihao Yue and Anwen Hu and Liang Zhang and Qin Jin},
93
+ year={2023},
94
+ eprint={2306.13460},
95
+ archivePrefix={arXiv},
96
+ primaryClass={cs.CL}
97
+ }
98
+ ```
99
+
100
+ ## Acknowledgement
101
+
102
+ Our work relies on resources from [BLIP](https://github.com/salesforce/BLIP) and [HuggingFace transformers](https://github.com/huggingface/transformers). Many thanks to them for their amazing efforts.
SMILE/__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import requests
4
+ import torch
5
+ from torchvision import transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ import sys
8
+ sys.path.append('SMILE/BLIP')
9
+ from models.model import caption_model
10
+
11
+
12
+ image_size = 384
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ transform = transforms.Compose([
15
+ transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
18
+ ])
19
+
20
+ model_url = {
21
+ 'smile': 'model/blip_smile_base.pth',
22
+ 'mle_smile': 'model/blip_mle_smile_base.pth',
23
+ }
24
+ model_smile = caption_model(pretrained=model_url['smile'], image_size=image_size, vit='base')
25
+ model_smile.eval()
26
+ model_smile = model_smile.to(device)
27
+ model_mle_smile = caption_model(pretrained=model_url['mle_smile'], image_size=image_size, vit='base')
28
+ model_mle_smile.eval()
29
+ model_mle_smile = model_mle_smile.to(device)
30
+
31
+
32
+ def generate_caption(raw_image, strategy):
33
+ image = transform(raw_image).unsqueeze(0).to(device)
34
+ with torch.no_grad():
35
+ if strategy == "More Descriptive":
36
+ caption = model_smile.generate(image, sample=False, num_beams=3, max_length=75, min_length=1)
37
+ else:
38
+ caption = model_mle_smile.generate(image, sample=False, num_beams=3, max_length=75, min_length=1)
39
+
40
+ return str(caption[0]).replace(' - ', '-').lower() + '.'
41
+
42
+
43
+ inputs = [
44
+ gr.Image(type="pil"),
45
+ gr.Radio(choices=["More Descriptive", "More Accurate"], default="More Descriptive", label="Strategy")
46
+ ]
47
+
48
+ outputs = "text"
49
+
50
+ examples = [
51
+ ["example/COCO_val2014_000000093534.jpg", "More Descriptive"],
52
+ ["example/COCO_val2014_000000411845.jpg", "More Descriptive"],
53
+ ["example/COCO_val2014_000000001682.jpg", "More Descriptive"],
54
+ ["example/COCO_val2014_000000473133.jpg", "More Descriptive"],
55
+ ["example/COCO_val2014_000000562150.jpg", "More Descriptive"]
56
+ ]
57
+
58
+ description = """<p style='text-align: center'>Gradio demo for BLIP-SMILE: The most descriptive captioning model before the multimodal LLM era.</p><p style='text-align: center'><a href='https://arxiv.org/abs/2306.13460' target='_blank'>Paper</a> | <a href='https://github.com/yuezih/SMILE' target='_blank'>Github</a></p>"""
59
+
60
+ interface = gr.Interface(
61
+ generate_caption,
62
+ inputs,
63
+ outputs,
64
+ examples=examples,
65
+ title="BLIP-SMILE",
66
+ description=description,
67
+ allow_flagging='never',
68
+ )
69
+
70
+ interface.launch(share=True)
example/COCO_val2014_000000001682.jpg ADDED
example/COCO_val2014_000000093534.jpg ADDED
example/COCO_val2014_000000411845.jpg ADDED
example/COCO_val2014_000000473133.jpg ADDED
example/COCO_val2014_000000562150.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
5
+ torch
6
+ torchvision
7
+ Pillow