helblazer811 committed
Commit 08f9860 · 1 parent: 227c367

Removed the current branch's concept_attention source code and install it as a package instead

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. app.py +1 -1
  2. concept_attention/__init__.py +0 -2
  3. concept_attention/binary_segmentation_baselines/__init__.py +0 -0
  4. concept_attention/binary_segmentation_baselines/__pycache__/__init__.cpython-310.pyc +0 -0
  5. concept_attention/binary_segmentation_baselines/__pycache__/chefer_clip_vit_baselines.cpython-310.pyc +0 -0
  6. concept_attention/binary_segmentation_baselines/__pycache__/clip_text_span_baseline.cpython-310.pyc +0 -0
  7. concept_attention/binary_segmentation_baselines/__pycache__/daam.cpython-310.pyc +0 -0
  8. concept_attention/binary_segmentation_baselines/__pycache__/daam_sd2.cpython-310.pyc +0 -0
  9. concept_attention/binary_segmentation_baselines/__pycache__/daam_sdxl.cpython-310.pyc +0 -0
  10. concept_attention/binary_segmentation_baselines/__pycache__/dino.cpython-310.pyc +0 -0
  11. concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc +0 -0
  12. concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc +0 -0
  13. concept_attention/binary_segmentation_baselines/__pycache__/raw_value_space.cpython-310.pyc +0 -0
  14. concept_attention/binary_segmentation_baselines/chefer_clip_vit_baselines.py +0 -272
  15. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_LRP.py +0 -437
  16. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_explanation_generator.py +0 -83
  17. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_new.py +0 -238
  18. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_orig_LRP.py +0 -425
  19. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_LRP.cpython-310.pyc +0 -0
  20. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_explanation_generator.cpython-310.pyc +0 -0
  21. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_new.cpython-310.pyc +0 -0
  22. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_orig_LRP.cpython-310.pyc +0 -0
  23. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/helpers.cpython-310.pyc +0 -0
  24. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/layer_helpers.cpython-310.pyc +0 -0
  25. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/weight_init.cpython-310.pyc +0 -0
  26. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/VOC.py +0 -395
  27. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__init__.py +0 -0
  28. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/Imagenet.cpython-310.pyc +0 -0
  29. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/VOC.cpython-310.pyc +0 -0
  30. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/__init__.cpython-310.pyc +0 -0
  31. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/imagenet.cpython-310.pyc +0 -0
  32. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet.py +0 -200
  33. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet_utils.py +0 -1002
  34. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/transforms.py +0 -442
  35. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/generate_visualizations.py +0 -208
  36. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/helpers.py +0 -295
  37. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/layer_helpers.py +0 -21
  38. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/misc_functions.py +0 -68
  39. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__init__.py +0 -0
  40. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  41. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_lrp.cpython-310.pyc +0 -0
  42. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_ours.cpython-310.pyc +0 -0
  43. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_lrp.py +0 -261
  44. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_ours.py +0 -280
  45. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/pertubation_eval_from_hdf5.py +0 -232
  46. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__init__.py +0 -0
  47. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  48. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/confusionmatrix.cpython-310.pyc +0 -0
  49. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/iou.cpython-310.pyc +0 -0
  50. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/metric.cpython-310.pyc +0 -0
app.py CHANGED
@@ -21,7 +21,7 @@ def update_default_concepts(prompt):
     return gr.update(value=default_concepts.get(prompt, []))
 
-pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", offload_model=True) # , device="cuda:0") # , offload_model=True)
+pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell")# , offload_model=True) # , device="cuda:0") # , offload_model=True)
 
 def convert_pil_to_bytes(img):
     img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
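
The only functional change in app.py is that the Gradio app no longer requests model offloading when it constructs the pipeline. Below is a minimal sketch of the new construction, assuming the installed concept_attention package exposes ConceptAttentionFluxPipeline at the top level (as the deleted __init__.py below did); the commented-out generation call is hypothetical and only marks where the app would use the pipeline.

# Minimal sketch: build the pipeline the way app.py now does (no offload_model),
# assuming `concept_attention` is installed as a package rather than vendored.
from concept_attention import ConceptAttentionFluxPipeline

pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell")

# Hypothetical usage, for illustration only; the generation API is not part of this diff.
# output = pipeline.generate_image(prompt="a dog by a tree", concepts=["dog", "tree"])
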
concept_attention/__init__.py DELETED
@@ -1,2 +0,0 @@
-
-from concept_attention.concept_attention_pipeline import ConceptAttentionFluxPipeline
concept_attention/binary_segmentation_baselines/__init__.py DELETED
File without changes
concept_attention/binary_segmentation_baselines/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (214 Bytes)
 
concept_attention/binary_segmentation_baselines/__pycache__/chefer_clip_vit_baselines.cpython-310.pyc DELETED
Binary file (7.18 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/clip_text_span_baseline.cpython-310.pyc DELETED
Binary file (3.66 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/daam.cpython-310.pyc DELETED
Binary file (2.52 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/daam_sd2.cpython-310.pyc DELETED
Binary file (3.81 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/daam_sdxl.cpython-310.pyc DELETED
Binary file (4.69 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/dino.cpython-310.pyc DELETED
Binary file (2.93 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc DELETED
Binary file (5.85 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc DELETED
Binary file (5.83 kB)
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_value_space.cpython-310.pyc DELETED
Binary file (6.64 kB)
 
concept_attention/binary_segmentation_baselines/chefer_clip_vit_baselines.py DELETED
@@ -1,272 +0,0 @@
1
- """
2
- This is just a wrapper around the various baselines implemented in the
3
- Chefer et. al. Transformer Explainability repository.
4
-
5
- Implements
6
- - CheferLRPSegmentationModel
7
- - CheferRolloutSegmentationModel
8
- - CheferLastLayerAttentionSegmentationModel
9
- - CheferAttentionGradCAMSegmentationModel
10
- - CheferTransformerAttributionSegmentationModel
11
- - CheferFullLRPSegmentationModel
12
- - CheferLastLayerLRPSegmentationModel
13
- """
14
-
15
- # # segmentation test for the rollout baseline
16
- # if args.method == 'rollout':
17
- # Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14)
18
-
19
- # # segmentation test for the LRP baseline (this is full LRP, not partial)
20
- # elif args.method == 'full_lrp':
21
- # Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224)
22
-
23
- # # segmentation test for our method
24
- # elif args.method == 'transformer_attribution':
25
- # Res = lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14)
26
-
27
- # # segmentation test for the partial LRP baseline (last attn layer)
28
- # elif args.method == 'lrp_last_layer':
29
- # Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\
30
- # .reshape(batch_size, 1, 14, 14)
31
-
32
- # # segmentation test for the raw attention baseline (last attn layer)
33
- # elif args.method == 'attn_last_layer':
34
- # Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\
35
- # .reshape(batch_size, 1, 14, 14)
36
-
37
- # # segmentation test for the GradCam baseline (last attn layer)
38
- # elif args.method == 'attn_gradcam':
39
- # Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)
40
-
41
- # if args.method != 'full_lrp':
42
- # # interpolate to full image size (224,224)
43
- # Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
44
-
45
- import torch
46
- import PIL
47
-
48
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import LRP
49
- from concept_attention.segmentation import SegmentationAbstractClass
50
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import Baselines, LRP
51
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_new import vit_base_patch16_224
52
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_LRP import vit_base_patch16_224 as vit_LRP
53
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
54
-
55
-
56
- # # Model
57
- # model = vit_base_patch16_224(pretrained=True).cuda()
58
- # baselines = Baselines(model)
59
-
60
- # # LRP
61
- # model_LRP = vit_LRP(pretrained=True).cuda()
62
- # model_LRP.eval()
63
- # lrp = LRP(model_LRP)
64
-
65
- # # orig LRP
66
- # model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
67
- # model_orig_LRP.eval()
68
- # orig_lrp = LRP(model_orig_LRP)
69
-
70
- # model.eval()
71
-
72
- class CheferLRPSegmentationModel(SegmentationAbstractClass):
73
-
74
- def __init__(
75
- self,
76
- device: str = "cuda",
77
- width: int = 224,
78
- height: int = 224,
79
- ):
80
- """
81
- Initialize the segmentation model.
82
- """
83
- super(CheferLRPSegmentationModel, self).__init__()
84
- self.width = width
85
- self.height = height
86
- self.device = device
87
- # Load the LRP model
88
- model_orig_LRP = vit_orig_LRP(pretrained=True).to(self.device)
89
- model_orig_LRP.eval()
90
- self.orig_lrp = LRP(model_orig_LRP)
91
-
92
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
93
- """
94
- Takes a real image and generates a concept segmentation map
95
- it by adding noise and running the DiT on it.
96
- """
97
- if len(image.shape) == 3:
98
- image = image.unsqueeze(0)
99
-
100
- prediction_map = self.orig_lrp.generate_LRP(
101
- image.to(self.device),
102
- method="full"
103
- )
104
- prediction_map = prediction_map.unsqueeze(0)
105
- # Rescale the prediction map to 64x64
106
- prediction_map = torch.nn.functional.interpolate(
107
- prediction_map,
108
- size=(self.width, self.height),
109
- mode="nearest"
110
- ).reshape(1, self.width, self.height)
111
-
112
- return prediction_map, None
113
-
114
- class CheferRolloutSegmentationModel(SegmentationAbstractClass):
115
-
116
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
117
- super(CheferRolloutSegmentationModel, self).__init__()
118
- self.width = width
119
- self.height = height
120
- self.device = device
121
- model = vit_base_patch16_224(pretrained=True).to(device)
122
- self.baselines = Baselines(model)
123
-
124
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
125
- if len(image.shape) == 3:
126
- image = image.unsqueeze(0)
127
- prediction_map = self.baselines.generate_rollout(
128
- image.to(self.device), start_layer=1
129
- ).reshape(1, 1, 14, 14)
130
- # Rescale the prediction map to 64x64
131
- prediction_map = torch.nn.functional.interpolate(
132
- prediction_map,
133
- size=(self.width, self.height),
134
- mode="nearest"
135
- ).reshape(1, self.width, self.height)
136
-
137
- return prediction_map, None
138
-
139
-
140
- class CheferLastLayerAttentionSegmentationModel(SegmentationAbstractClass):
141
-
142
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
143
- super(CheferLastLayerAttentionSegmentationModel, self).__init__()
144
- self.width = width
145
- self.height = height
146
- self.device = device
147
- model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
148
- model_orig_LRP.eval()
149
- self.orig_lrp = LRP(model_orig_LRP)
150
-
151
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
152
- if len(image.shape) == 3:
153
- image = image.unsqueeze(0)
154
-
155
- prediction_map = self.orig_lrp.generate_LRP(
156
- image.to(self.device), method="last_layer_attn"
157
- ).reshape(1, 1, 14, 14)
158
- # Rescale the prediction map to 64x64
159
- prediction_map = torch.nn.functional.interpolate(
160
- prediction_map,
161
- size=(self.width, self.height),
162
- mode="nearest"
163
- ).reshape(1, self.width, self.height)
164
-
165
- return prediction_map, None
166
-
167
-
168
- class CheferAttentionGradCAMSegmentationModel(SegmentationAbstractClass):
169
-
170
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
171
- super(CheferAttentionGradCAMSegmentationModel, self).__init__()
172
- self.width = width
173
- self.height = height
174
- self.device = device
175
- model = vit_base_patch16_224(pretrained=True).to(device)
176
- self.baselines = Baselines(model)
177
-
178
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
179
- if len(image.shape) == 3:
180
- image = image.unsqueeze(0)
181
- prediction_map = self.baselines.generate_cam_attn(
182
- image.to(self.device)
183
- ).reshape(1, 1, 14, 14)
184
- # Rescale the prediction map to 64x64
185
- prediction_map = torch.nn.functional.interpolate(
186
- prediction_map,
187
- size=(self.width, self.height),
188
- mode="nearest"
189
- ).reshape(1, self.width, self.height)
190
-
191
- return prediction_map, None
192
-
193
-
194
- class CheferTransformerAttributionSegmentationModel(SegmentationAbstractClass):
195
-
196
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
197
- super(CheferTransformerAttributionSegmentationModel, self).__init__()
198
- self.width = width
199
- self.height = height
200
- self.device = device
201
- model_LRP = vit_LRP(pretrained=True).to(device)
202
- model_LRP.eval()
203
- self.lrp = LRP(model_LRP)
204
-
205
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
206
- if len(image.shape) == 3:
207
- image = image.unsqueeze(0)
208
- prediction_map = self.lrp.generate_LRP(
209
- image.to(self.device), start_layer=1, method="transformer_attribution"
210
- ).reshape(1, 1, 14, 14)
211
- # Rescale the prediction map to 64x64
212
- prediction_map = torch.nn.functional.interpolate(
213
- prediction_map,
214
- size=(self.width, self.height),
215
- mode="nearest"
216
- ).reshape(1, self.width, self.height)
217
-
218
- return prediction_map, None
219
-
220
-
221
- class CheferFullLRPSegmentationModel(SegmentationAbstractClass):
222
-
223
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
224
- super(CheferFullLRPSegmentationModel, self).__init__()
225
- self.width = width
226
- self.height = height
227
- self.device = device
228
- model_LRP = vit_LRP(pretrained=True).to(device)
229
- model_LRP.eval()
230
- self.lrp = LRP(model_LRP)
231
-
232
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
233
- if len(image.shape) == 3:
234
- image = image.unsqueeze(0)
235
- prediction_map = self.lrp.generate_LRP(
236
- image.to(self.device), method="full"
237
- ).reshape(1, 1, 224, 224)
238
- # Rescale the prediction map to 64x64
239
- prediction_map = torch.nn.functional.interpolate(
240
- prediction_map,
241
- size=(self.width, self.height),
242
- mode="nearest"
243
- ).reshape(1, self.width, self.height)
244
-
245
- return prediction_map, None
246
-
247
-
248
- class CheferLastLayerLRPSegmentationModel(SegmentationAbstractClass):
249
-
250
- def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
251
- super(CheferLastLayerLRPSegmentationModel, self).__init__()
252
- self.width = width
253
- self.height = height
254
- self.device = device
255
- model_LRP = vit_LRP(pretrained=True).to(device)
256
- model_LRP.eval()
257
- self.lrp = LRP(model_LRP)
258
-
259
- def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
260
- if len(image.shape) == 3:
261
- image = image.unsqueeze(0)
262
- prediction_map = self.lrp.generate_LRP(
263
- image.to(self.device), method="last_layer"
264
- ).reshape(1, 1, 14, 14)
265
- # Rescale the prediction map to 64x64
266
- prediction_map = torch.nn.functional.interpolate(
267
- prediction_map,
268
- size=(self.width, self.height),
269
- mode="nearest"
270
- ).reshape(1, self.width, self.height)
271
-
272
- return prediction_map, None
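
Every wrapper removed here ends the same way: an attribution map on a coarse grid (14x14 patch tokens, or 224x224 for full LRP) is upsampled to the requested output size with nearest-neighbor interpolation and returned with shape (1, width, height). A self-contained sketch of that shared post-processing step, with a random tensor standing in for a real attribution map:

# Upsampling step shared by all of the removed Chefer segmentation wrappers.
import torch
import torch.nn.functional as F

width, height = 224, 224
low_res_map = torch.rand(1, 1, 14, 14)   # stand-in for a 14x14 ViT attribution map

prediction_map = F.interpolate(
    low_res_map, size=(width, height), mode="nearest"
).reshape(1, width, height)

print(prediction_map.shape)              # torch.Size([1, 224, 224])
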
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_LRP.py DELETED
@@ -1,437 +0,0 @@
1
- """ Vision Transformer (ViT) in PyTorch
2
- Hacked together by / Copyright 2020 Ross Wightman
3
- """
4
- import torch
5
- import torch.nn as nn
6
- from einops import rearrange
7
-
8
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_ours import *
9
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
10
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
11
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
12
-
13
-
14
- def _cfg(url='', **kwargs):
15
- return {
16
- 'url': url,
17
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
- 'crop_pct': .9, 'interpolation': 'bicubic',
19
- 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
- **kwargs
21
- }
22
-
23
-
24
- default_cfgs = {
25
- # patch models
26
- 'vit_small_patch16_224': _cfg(
27
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
- ),
29
- 'vit_base_patch16_224': _cfg(
30
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
- ),
33
- 'vit_large_patch16_224': _cfg(
34
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
- }
37
-
38
- def compute_rollout_attention(all_layer_matrices, start_layer=0):
39
- # adding residual consideration
40
- num_tokens = all_layer_matrices[0].shape[1]
41
- batch_size = all_layer_matrices[0].shape[0]
42
- eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
43
- all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
44
- # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
- # for i in range(len(all_layer_matrices))]
46
- joint_attention = all_layer_matrices[start_layer]
47
- for i in range(start_layer+1, len(all_layer_matrices)):
48
- joint_attention = all_layer_matrices[i].bmm(joint_attention)
49
- return joint_attention
50
-
51
- class Mlp(nn.Module):
52
- def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
53
- super().__init__()
54
- out_features = out_features or in_features
55
- hidden_features = hidden_features or in_features
56
- self.fc1 = Linear(in_features, hidden_features)
57
- self.act = GELU()
58
- self.fc2 = Linear(hidden_features, out_features)
59
- self.drop = Dropout(drop)
60
-
61
- def forward(self, x):
62
- x = self.fc1(x)
63
- x = self.act(x)
64
- x = self.drop(x)
65
- x = self.fc2(x)
66
- x = self.drop(x)
67
- return x
68
-
69
- def relprop(self, cam, **kwargs):
70
- cam = self.drop.relprop(cam, **kwargs)
71
- cam = self.fc2.relprop(cam, **kwargs)
72
- cam = self.act.relprop(cam, **kwargs)
73
- cam = self.fc1.relprop(cam, **kwargs)
74
- return cam
75
-
76
-
77
- class Attention(nn.Module):
78
- def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
79
- super().__init__()
80
- self.num_heads = num_heads
81
- head_dim = dim // num_heads
82
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
83
- self.scale = head_dim ** -0.5
84
-
85
- # A = Q*K^T
86
- self.matmul1 = einsum('bhid,bhjd->bhij')
87
- # attn = A*V
88
- self.matmul2 = einsum('bhij,bhjd->bhid')
89
-
90
- self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
91
- self.attn_drop = Dropout(attn_drop)
92
- self.proj = Linear(dim, dim)
93
- self.proj_drop = Dropout(proj_drop)
94
- self.softmax = Softmax(dim=-1)
95
-
96
- self.attn_cam = None
97
- self.attn = None
98
- self.v = None
99
- self.v_cam = None
100
- self.attn_gradients = None
101
-
102
- def get_attn(self):
103
- return self.attn
104
-
105
- def save_attn(self, attn):
106
- self.attn = attn
107
-
108
- def save_attn_cam(self, cam):
109
- self.attn_cam = cam
110
-
111
- def get_attn_cam(self):
112
- return self.attn_cam
113
-
114
- def get_v(self):
115
- return self.v
116
-
117
- def save_v(self, v):
118
- self.v = v
119
-
120
- def save_v_cam(self, cam):
121
- self.v_cam = cam
122
-
123
- def get_v_cam(self):
124
- return self.v_cam
125
-
126
- def save_attn_gradients(self, attn_gradients):
127
- self.attn_gradients = attn_gradients
128
-
129
- def get_attn_gradients(self):
130
- return self.attn_gradients
131
-
132
- def forward(self, x):
133
- b, n, _, h = *x.shape, self.num_heads
134
- qkv = self.qkv(x)
135
- q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
136
-
137
- self.save_v(v)
138
-
139
- dots = self.matmul1([q, k]) * self.scale
140
-
141
- attn = self.softmax(dots)
142
- attn = self.attn_drop(attn)
143
-
144
- self.save_attn(attn)
145
- attn.register_hook(self.save_attn_gradients)
146
-
147
- out = self.matmul2([attn, v])
148
- out = rearrange(out, 'b h n d -> b n (h d)')
149
-
150
- out = self.proj(out)
151
- out = self.proj_drop(out)
152
- return out
153
-
154
- def relprop(self, cam, **kwargs):
155
- cam = self.proj_drop.relprop(cam, **kwargs)
156
- cam = self.proj.relprop(cam, **kwargs)
157
- cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
158
-
159
- # attn = A*V
160
- (cam1, cam_v)= self.matmul2.relprop(cam, **kwargs)
161
- cam1 /= 2
162
- cam_v /= 2
163
-
164
- self.save_v_cam(cam_v)
165
- self.save_attn_cam(cam1)
166
-
167
- cam1 = self.attn_drop.relprop(cam1, **kwargs)
168
- cam1 = self.softmax.relprop(cam1, **kwargs)
169
-
170
- # A = Q*K^T
171
- (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
172
- cam_q /= 2
173
- cam_k /= 2
174
-
175
- cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
176
-
177
- return self.qkv.relprop(cam_qkv, **kwargs)
178
-
179
-
180
- class Block(nn.Module):
181
-
182
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
183
- super().__init__()
184
- self.norm1 = LayerNorm(dim, eps=1e-6)
185
- self.attn = Attention(
186
- dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
187
- self.norm2 = LayerNorm(dim, eps=1e-6)
188
- mlp_hidden_dim = int(dim * mlp_ratio)
189
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
190
-
191
- self.add1 = Add()
192
- self.add2 = Add()
193
- self.clone1 = Clone()
194
- self.clone2 = Clone()
195
-
196
- def forward(self, x):
197
- x1, x2 = self.clone1(x, 2)
198
- x = self.add1([x1, self.attn(self.norm1(x2))])
199
- x1, x2 = self.clone2(x, 2)
200
- x = self.add2([x1, self.mlp(self.norm2(x2))])
201
- return x
202
-
203
- def relprop(self, cam, **kwargs):
204
- (cam1, cam2) = self.add2.relprop(cam, **kwargs)
205
- cam2 = self.mlp.relprop(cam2, **kwargs)
206
- cam2 = self.norm2.relprop(cam2, **kwargs)
207
- cam = self.clone2.relprop((cam1, cam2), **kwargs)
208
-
209
- (cam1, cam2) = self.add1.relprop(cam, **kwargs)
210
- cam2 = self.attn.relprop(cam2, **kwargs)
211
- cam2 = self.norm1.relprop(cam2, **kwargs)
212
- cam = self.clone1.relprop((cam1, cam2), **kwargs)
213
- return cam
214
-
215
-
216
- class PatchEmbed(nn.Module):
217
- """ Image to Patch Embedding
218
- """
219
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
220
- super().__init__()
221
- img_size = to_2tuple(img_size)
222
- patch_size = to_2tuple(patch_size)
223
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
224
- self.img_size = img_size
225
- self.patch_size = patch_size
226
- self.num_patches = num_patches
227
-
228
- self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
229
-
230
- def forward(self, x):
231
- B, C, H, W = x.shape
232
- # FIXME look at relaxing size constraints
233
- assert H == self.img_size[0] and W == self.img_size[1], \
234
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
235
- x = self.proj(x).flatten(2).transpose(1, 2)
236
- return x
237
-
238
- def relprop(self, cam, **kwargs):
239
- cam = cam.transpose(1,2)
240
- cam = cam.reshape(cam.shape[0], cam.shape[1],
241
- (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
242
- return self.proj.relprop(cam, **kwargs)
243
-
244
-
245
- class VisionTransformer(nn.Module):
246
- """ Vision Transformer with support for patch or hybrid CNN input stage
247
- """
248
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
- num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
250
- super().__init__()
251
- self.num_classes = num_classes
252
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
253
- self.patch_embed = PatchEmbed(
254
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
255
- num_patches = self.patch_embed.num_patches
256
-
257
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
258
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
259
-
260
- self.blocks = nn.ModuleList([
261
- Block(
262
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
263
- drop=drop_rate, attn_drop=attn_drop_rate)
264
- for i in range(depth)])
265
-
266
- self.norm = LayerNorm(embed_dim)
267
- if mlp_head:
268
- # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
269
- self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
270
- else:
271
- # with a single Linear layer as head, the param count within rounding of paper
272
- self.head = Linear(embed_dim, num_classes)
273
-
274
- # FIXME not quite sure what the proper weight init is supposed to be,
275
- # normal / trunc normal w/ std == .02 similar to other Bert like transformers
276
- trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights?
277
- trunc_normal_(self.cls_token, std=.02)
278
- self.apply(self._init_weights)
279
-
280
- self.pool = IndexSelect()
281
- self.add = Add()
282
-
283
- self.inp_grad = None
284
-
285
- def save_inp_grad(self,grad):
286
- self.inp_grad = grad
287
-
288
- def get_inp_grad(self):
289
- return self.inp_grad
290
-
291
-
292
- def _init_weights(self, m):
293
- if isinstance(m, nn.Linear):
294
- trunc_normal_(m.weight, std=.02)
295
- if isinstance(m, nn.Linear) and m.bias is not None:
296
- nn.init.constant_(m.bias, 0)
297
- elif isinstance(m, nn.LayerNorm):
298
- nn.init.constant_(m.bias, 0)
299
- nn.init.constant_(m.weight, 1.0)
300
-
301
- @property
302
- def no_weight_decay(self):
303
- return {'pos_embed', 'cls_token'}
304
-
305
- def forward(self, x):
306
- B = x.shape[0]
307
- x = self.patch_embed(x)
308
-
309
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
310
- x = torch.cat((cls_tokens, x), dim=1)
311
- x = self.add([x, self.pos_embed])
312
-
313
- x.register_hook(self.save_inp_grad)
314
-
315
- for blk in self.blocks:
316
- x = blk(x)
317
-
318
- x = self.norm(x)
319
- x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
320
- x = x.squeeze(1)
321
- x = self.head(x)
322
- return x
323
-
324
- def relprop(self, cam=None,method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
325
- # print(kwargs)
326
- # print("conservation 1", cam.sum())
327
- cam = self.head.relprop(cam, **kwargs)
328
- cam = cam.unsqueeze(1)
329
- cam = self.pool.relprop(cam, **kwargs)
330
- cam = self.norm.relprop(cam, **kwargs)
331
- for blk in reversed(self.blocks):
332
- cam = blk.relprop(cam, **kwargs)
333
-
334
- # print("conservation 2", cam.sum())
335
- # print("min", cam.min())
336
-
337
- if method == "full":
338
- (cam, _) = self.add.relprop(cam, **kwargs)
339
- cam = cam[:, 1:]
340
- cam = self.patch_embed.relprop(cam, **kwargs)
341
- # sum on channels
342
- cam = cam.sum(dim=1)
343
- return cam
344
-
345
- elif method == "rollout":
346
- # cam rollout
347
- attn_cams = []
348
- for blk in self.blocks:
349
- attn_heads = blk.attn.get_attn_cam().clamp(min=0)
350
- avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
351
- attn_cams.append(avg_heads)
352
- cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
353
- cam = cam[:, 0, 1:]
354
- return cam
355
-
356
- # our method, method name grad is legacy
357
- elif method == "transformer_attribution" or method == "grad":
358
- cams = []
359
- for blk in self.blocks:
360
- grad = blk.attn.get_attn_gradients()
361
- cam = blk.attn.get_attn_cam()
362
- cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
363
- grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
364
- cam = grad * cam
365
- cam = cam.clamp(min=0).mean(dim=0)
366
- cams.append(cam.unsqueeze(0))
367
- rollout = compute_rollout_attention(cams, start_layer=start_layer)
368
- cam = rollout[:, 0, 1:]
369
- return cam
370
-
371
- elif method == "last_layer":
372
- cam = self.blocks[-1].attn.get_attn_cam()
373
- cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
374
- if is_ablation:
375
- grad = self.blocks[-1].attn.get_attn_gradients()
376
- grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
377
- cam = grad * cam
378
- cam = cam.clamp(min=0).mean(dim=0)
379
- cam = cam[0, 1:]
380
- return cam
381
-
382
- elif method == "last_layer_attn":
383
- cam = self.blocks[-1].attn.get_attn()
384
- cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
385
- cam = cam.clamp(min=0).mean(dim=0)
386
- cam = cam[0, 1:]
387
- return cam
388
-
389
- elif method == "second_layer":
390
- cam = self.blocks[1].attn.get_attn_cam()
391
- cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
392
- if is_ablation:
393
- grad = self.blocks[1].attn.get_attn_gradients()
394
- grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
395
- cam = grad * cam
396
- cam = cam.clamp(min=0).mean(dim=0)
397
- cam = cam[0, 1:]
398
- return cam
399
-
400
-
401
- def _conv_filter(state_dict, patch_size=16):
402
- """ convert patch embedding weight from manual patchify + linear proj to conv"""
403
- out_dict = {}
404
- for k, v in state_dict.items():
405
- if 'patch_embed.proj.weight' in k:
406
- v = v.reshape((v.shape[0], 3, patch_size, patch_size))
407
- out_dict[k] = v
408
- return out_dict
409
-
410
- def vit_base_patch16_224(pretrained=False, **kwargs):
411
- model = VisionTransformer(
412
- patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
413
- model.default_cfg = default_cfgs['vit_base_patch16_224']
414
- if pretrained:
415
- load_pretrained(
416
- model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
417
- return model
418
-
419
- def vit_large_patch16_224(pretrained=False, **kwargs):
420
- model = VisionTransformer(
421
- patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
422
- model.default_cfg = default_cfgs['vit_large_patch16_224']
423
- if pretrained:
424
- load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
425
- return model
426
-
427
- def deit_base_patch16_224(pretrained=False, **kwargs):
428
- model = VisionTransformer(
429
- patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
430
- model.default_cfg = _cfg()
431
- if pretrained:
432
- checkpoint = torch.hub.load_state_dict_from_url(
433
- url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
434
- map_location="cpu", check_hash=True
435
- )
436
- model.load_state_dict(checkpoint["model"])
437
- return model
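
The attention-rollout helper deleted with this file is self-contained, so its behavior is easy to reproduce in isolation. A runnable sketch that mirrors the deleted compute_rollout_attention (add the identity to account for residual connections, then chain batch matrix products from start_layer onward), applied to random head-averaged attention maps:

# Runnable sketch of the rollout computation from the deleted ViT_LRP.py.
import torch

def compute_rollout_attention(all_layer_matrices, start_layer=0):
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens)
    all_layer_matrices = [m + eye.to(m.device) for m in all_layer_matrices]
    joint_attention = all_layer_matrices[start_layer]
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention

# 12 layers of head-averaged attention for one image with 197 tokens (1 CLS + 196 patches).
layers = [torch.softmax(torch.rand(1, 197, 197), dim=-1) for _ in range(12)]
rollout = compute_rollout_attention(layers, start_layer=1)
cls_to_patches = rollout[:, 0, 1:]   # relevance of each patch token to the CLS token
print(cls_to_patches.shape)          # torch.Size([1, 196])
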
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_explanation_generator.py DELETED
@@ -1,83 +0,0 @@
1
- import argparse
2
- import torch
3
- import numpy as np
4
- from numpy import *
5
-
6
- # compute rollout between attention layers
7
- def compute_rollout_attention(all_layer_matrices, start_layer=0):
8
- # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
9
- num_tokens = all_layer_matrices[0].shape[1]
10
- batch_size = all_layer_matrices[0].shape[0]
11
- eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
12
- all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
13
- matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
14
- for i in range(len(all_layer_matrices))]
15
- joint_attention = matrices_aug[start_layer]
16
- for i in range(start_layer+1, len(matrices_aug)):
17
- joint_attention = matrices_aug[i].bmm(joint_attention)
18
- return joint_attention
19
-
20
- class LRP:
21
- def __init__(self, model):
22
- self.model = model
23
- self.model.eval()
24
-
25
- def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
26
- output = self.model(input)
27
- kwargs = {"alpha": 1}
28
- if index == None:
29
- index = np.argmax(output.cpu().data.numpy(), axis=-1)
30
-
31
- one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
32
- one_hot[0, index] = 1
33
- one_hot_vector = one_hot
34
- one_hot = torch.from_numpy(one_hot).requires_grad_(True)
35
- one_hot = torch.sum(one_hot.to(input.device) * output)
36
-
37
- self.model.zero_grad()
38
- one_hot.backward(retain_graph=True)
39
-
40
- return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
41
- start_layer=start_layer, **kwargs)
42
-
43
-
44
-
45
- class Baselines:
46
- def __init__(self, model):
47
- self.model = model
48
- self.model.eval()
49
-
50
- def generate_cam_attn(self, input, index=None):
51
- output = self.model(input, register_hook=True)
52
- if index == None:
53
- index = np.argmax(output.cpu().data.numpy())
54
-
55
- one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
56
- one_hot[0][index] = 1
57
- one_hot = torch.from_numpy(one_hot).requires_grad_(True)
58
- one_hot = torch.sum(one_hot.to(output.device) * output)
59
-
60
- self.model.zero_grad()
61
- one_hot.backward(retain_graph=True)
62
- #################### attn
63
- grad = self.model.blocks[-1].attn.get_attn_gradients()
64
- cam = self.model.blocks[-1].attn.get_attention_map()
65
- cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
66
- grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
67
- grad = grad.mean(dim=[1, 2], keepdim=True)
68
- cam = (cam * grad).mean(0).clamp(min=0)
69
- cam = (cam - cam.min()) / (cam.max() - cam.min())
70
-
71
- return cam
72
- #################### attn
73
-
74
- def generate_rollout(self, input, start_layer=0):
75
- self.model(input)
76
- blocks = self.model.blocks
77
- all_layer_attentions = []
78
- for blk in blocks:
79
- attn_heads = blk.attn.get_attention_map()
80
- avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
81
- all_layer_attentions.append(avg_heads)
82
- rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
83
- return rollout[:,0, 1:]
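
Both LRP.generate_LRP and Baselines.generate_cam_attn above rely on the same mechanism: build a one-hot vector for the predicted (or requested) class, reduce the model output to that single logit, and backpropagate it so the hooks registered in the model capture attention gradients. A self-contained sketch of that mechanism with a stand-in linear classifier:

# Target-logit selection and backward pass, as used by the deleted explanation generator.
import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(16, 10)    # stand-in for the ViT classifier
x = torch.rand(1, 16)

output = model(x)            # shape (1, num_classes)
index = np.argmax(output.detach().cpu().numpy(), axis=-1)

one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
one_hot[0, index] = 1
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
score = torch.sum(one_hot * output)   # scalar: the selected class logit

model.zero_grad()
score.backward(retain_graph=True)     # in the real model, this fills the saved attention gradients
print(model.weight.grad.shape)        # torch.Size([10, 16])
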
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_new.py DELETED
@@ -1,238 +0,0 @@
1
- """ Vision Transformer (ViT) in PyTorch
2
- Hacked together by / Copyright 2020 Ross Wightman
3
- """
4
- import torch
5
- import torch.nn as nn
6
- from functools import partial
7
- from einops import rearrange
8
-
9
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
10
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
11
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
12
-
13
-
14
- def _cfg(url='', **kwargs):
15
- return {
16
- 'url': url,
17
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
- 'crop_pct': .9, 'interpolation': 'bicubic',
19
- 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
- **kwargs
21
- }
22
-
23
-
24
- default_cfgs = {
25
- # patch models
26
- 'vit_small_patch16_224': _cfg(
27
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
- ),
29
- 'vit_base_patch16_224': _cfg(
30
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
- ),
33
- 'vit_large_patch16_224': _cfg(
34
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
- }
37
-
38
- class Mlp(nn.Module):
39
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
40
- super().__init__()
41
- out_features = out_features or in_features
42
- hidden_features = hidden_features or in_features
43
- self.fc1 = nn.Linear(in_features, hidden_features)
44
- self.act = act_layer()
45
- self.fc2 = nn.Linear(hidden_features, out_features)
46
- self.drop = nn.Dropout(drop)
47
-
48
- def forward(self, x):
49
- x = self.fc1(x)
50
- x = self.act(x)
51
- x = self.drop(x)
52
- x = self.fc2(x)
53
- x = self.drop(x)
54
- return x
55
-
56
-
57
- class Attention(nn.Module):
58
- def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
59
- super().__init__()
60
- self.num_heads = num_heads
61
- head_dim = dim // num_heads
62
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
63
- self.scale = head_dim ** -0.5
64
-
65
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
- self.attn_drop = nn.Dropout(attn_drop)
67
- self.proj = nn.Linear(dim, dim)
68
- self.proj_drop = nn.Dropout(proj_drop)
69
-
70
- self.attn_gradients = None
71
- self.attention_map = None
72
-
73
- def save_attn_gradients(self, attn_gradients):
74
- self.attn_gradients = attn_gradients
75
-
76
- def get_attn_gradients(self):
77
- return self.attn_gradients
78
-
79
- def save_attention_map(self, attention_map):
80
- self.attention_map = attention_map
81
-
82
- def get_attention_map(self):
83
- return self.attention_map
84
-
85
- def forward(self, x, register_hook=False):
86
- b, n, _, h = *x.shape, self.num_heads
87
-
88
- # self.save_output(x)
89
- # x.register_hook(self.save_output_grad)
90
-
91
- qkv = self.qkv(x)
92
- q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
93
-
94
- dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
95
-
96
- attn = dots.softmax(dim=-1)
97
- attn = self.attn_drop(attn)
98
-
99
- out = torch.einsum('bhij,bhjd->bhid', attn, v)
100
-
101
- self.save_attention_map(attn)
102
- if register_hook:
103
- attn.register_hook(self.save_attn_gradients)
104
-
105
- out = rearrange(out, 'b h n d -> b n (h d)')
106
- out = self.proj(out)
107
- out = self.proj_drop(out)
108
- return out
109
-
110
-
111
- class Block(nn.Module):
112
-
113
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
114
- super().__init__()
115
- self.norm1 = norm_layer(dim)
116
- self.attn = Attention(
117
- dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
118
- self.norm2 = norm_layer(dim)
119
- mlp_hidden_dim = int(dim * mlp_ratio)
120
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
121
-
122
- def forward(self, x, register_hook=False):
123
- x = x + self.attn(self.norm1(x), register_hook=register_hook)
124
- x = x + self.mlp(self.norm2(x))
125
- return x
126
-
127
-
128
- class PatchEmbed(nn.Module):
129
- """ Image to Patch Embedding
130
- """
131
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
132
- super().__init__()
133
- img_size = to_2tuple(img_size)
134
- patch_size = to_2tuple(patch_size)
135
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
136
- self.img_size = img_size
137
- self.patch_size = patch_size
138
- self.num_patches = num_patches
139
-
140
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
141
-
142
- def forward(self, x):
143
- B, C, H, W = x.shape
144
- # FIXME look at relaxing size constraints
145
- assert H == self.img_size[0] and W == self.img_size[1], \
146
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
147
- x = self.proj(x).flatten(2).transpose(1, 2)
148
- return x
149
-
150
- class VisionTransformer(nn.Module):
151
- """ Vision Transformer
152
- """
153
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
154
- num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
155
- super().__init__()
156
- self.num_classes = num_classes
157
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
158
- self.patch_embed = PatchEmbed(
159
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
160
- num_patches = self.patch_embed.num_patches
161
-
162
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
163
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
164
- self.pos_drop = nn.Dropout(p=drop_rate)
165
-
166
- self.blocks = nn.ModuleList([
167
- Block(
168
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
169
- drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
170
- for i in range(depth)])
171
- self.norm = norm_layer(embed_dim)
172
-
173
- # Classifier head
174
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
175
-
176
- trunc_normal_(self.pos_embed, std=.02)
177
- trunc_normal_(self.cls_token, std=.02)
178
- self.apply(self._init_weights)
179
-
180
- def _init_weights(self, m):
181
- if isinstance(m, nn.Linear):
182
- trunc_normal_(m.weight, std=.02)
183
- if isinstance(m, nn.Linear) and m.bias is not None:
184
- nn.init.constant_(m.bias, 0)
185
- elif isinstance(m, nn.LayerNorm):
186
- nn.init.constant_(m.bias, 0)
187
- nn.init.constant_(m.weight, 1.0)
188
-
189
- @torch.jit.ignore
190
- def no_weight_decay(self):
191
- return {'pos_embed', 'cls_token'}
192
-
193
- def forward(self, x, register_hook=False):
194
- B = x.shape[0]
195
- x = self.patch_embed(x)
196
-
197
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
198
- x = torch.cat((cls_tokens, x), dim=1)
199
- x = x + self.pos_embed
200
- x = self.pos_drop(x)
201
-
202
- for blk in self.blocks:
203
- x = blk(x, register_hook=register_hook)
204
-
205
- x = self.norm(x)
206
- x = x[:, 0]
207
- x = self.head(x)
208
- return x
209
-
210
-
211
- def _conv_filter(state_dict, patch_size=16):
212
- """ convert patch embedding weight from manual patchify + linear proj to conv"""
213
- out_dict = {}
214
- for k, v in state_dict.items():
215
- if 'patch_embed.proj.weight' in k:
216
- v = v.reshape((v.shape[0], 3, patch_size, patch_size))
217
- out_dict[k] = v
218
- return out_dict
219
-
220
-
221
- def vit_base_patch16_224(pretrained=False, **kwargs):
222
- model = VisionTransformer(
223
- patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
224
- norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
225
- model.default_cfg = default_cfgs['vit_base_patch16_224']
226
- if pretrained:
227
- load_pretrained(
228
- model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
229
- return model
230
-
231
- def vit_large_patch16_224(pretrained=False, **kwargs):
232
- model = VisionTransformer(
233
- patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
234
- norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
235
- model.default_cfg = default_cfgs['vit_large_patch16_224']
236
- if pretrained:
237
- load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
238
- return model
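
ViT_new keeps the standard nn modules and only instruments the attention tensor: the softmaxed attention map is saved and, when register_hook=True, a backward hook stores its gradient for Grad-CAM-style explanations. A self-contained sketch of that hook pattern on a plain attention tensor:

# Hook pattern used by ViT_new.Attention to cache attention gradients.
import torch

saved = {}

def save_attn_gradients(grad):
    saved["attn_grad"] = grad

# Stand-in for a head-wise attention map: (batch, heads, tokens, tokens).
scores = torch.rand(1, 12, 197, 197, requires_grad=True)
attn = torch.softmax(scores, dim=-1)
attn.register_hook(save_attn_gradients)   # fires during backward

loss = attn[:, :, 0, 1:].sum()            # any scalar that depends on the attention map
loss.backward()
print(saved["attn_grad"].shape)           # torch.Size([1, 12, 197, 197])
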
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_orig_LRP.py DELETED
@@ -1,425 +0,0 @@
1
- """ Vision Transformer (ViT) in PyTorch
2
- Hacked together by / Copyright 2020 Ross Wightman
3
- """
4
- import torch
5
- import torch.nn as nn
6
- from einops import rearrange
7
-
8
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_lrp import *
9
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
10
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
11
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
12
-
13
-
14
- def _cfg(url='', **kwargs):
15
- return {
16
- 'url': url,
17
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
- 'crop_pct': .9, 'interpolation': 'bicubic',
19
- 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
- **kwargs
21
- }
22
-
23
-
24
- default_cfgs = {
25
- # patch models
26
- 'vit_small_patch16_224': _cfg(
27
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
- ),
29
- 'vit_base_patch16_224': _cfg(
30
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
- ),
33
- 'vit_large_patch16_224': _cfg(
34
- url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
- mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
- }
37
-
38
- def compute_rollout_attention(all_layer_matrices, start_layer=0):
39
- # adding residual consideration
40
- num_tokens = all_layer_matrices[0].shape[1]
41
- batch_size = all_layer_matrices[0].shape[0]
42
- eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
43
- all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
44
- # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
45
- # for i in range(len(all_layer_matrices))]
46
- joint_attention = all_layer_matrices[start_layer]
47
- for i in range(start_layer+1, len(all_layer_matrices)):
48
- joint_attention = all_layer_matrices[i].bmm(joint_attention)
49
- return joint_attention
50
-
51
- class Mlp(nn.Module):
52
- def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
53
- super().__init__()
54
- out_features = out_features or in_features
55
- hidden_features = hidden_features or in_features
56
- self.fc1 = Linear(in_features, hidden_features)
57
- self.act = GELU()
58
- self.fc2 = Linear(hidden_features, out_features)
59
- self.drop = Dropout(drop)
60
-
61
- def forward(self, x):
62
- x = self.fc1(x)
63
- x = self.act(x)
64
- x = self.drop(x)
65
- x = self.fc2(x)
66
- x = self.drop(x)
67
- return x
68
-
69
- def relprop(self, cam, **kwargs):
70
- cam = self.drop.relprop(cam, **kwargs)
71
- cam = self.fc2.relprop(cam, **kwargs)
72
- cam = self.act.relprop(cam, **kwargs)
73
- cam = self.fc1.relprop(cam, **kwargs)
74
- return cam
75
-
76
-
77
- class Attention(nn.Module):
78
- def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
79
- super().__init__()
80
- self.num_heads = num_heads
81
- head_dim = dim // num_heads
82
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
83
- self.scale = head_dim ** -0.5
84
-
85
- # A = Q*K^T
86
- self.matmul1 = einsum('bhid,bhjd->bhij')
87
- # attn = A*V
88
- self.matmul2 = einsum('bhij,bhjd->bhid')
89
-
90
- self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
91
- self.attn_drop = Dropout(attn_drop)
92
- self.proj = Linear(dim, dim)
93
- self.proj_drop = Dropout(proj_drop)
94
- self.softmax = Softmax(dim=-1)
95
-
96
- self.attn_cam = None
97
- self.attn = None
98
- self.v = None
99
- self.v_cam = None
100
- self.attn_gradients = None
101
-
102
- def get_attn(self):
103
- return self.attn
104
-
105
- def save_attn(self, attn):
106
- self.attn = attn
107
-
108
- def save_attn_cam(self, cam):
109
- self.attn_cam = cam
110
-
111
- def get_attn_cam(self):
112
-         return self.attn_cam
-
-     def get_v(self):
-         return self.v
-
-     def save_v(self, v):
-         self.v = v
-
-     def save_v_cam(self, cam):
-         self.v_cam = cam
-
-     def get_v_cam(self):
-         return self.v_cam
-
-     def save_attn_gradients(self, attn_gradients):
-         self.attn_gradients = attn_gradients
-
-     def get_attn_gradients(self):
-         return self.attn_gradients
-
-     def forward(self, x):
-         b, n, _, h = *x.shape, self.num_heads
-         qkv = self.qkv(x)
-         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
-
-         self.save_v(v)
-
-         dots = self.matmul1([q, k]) * self.scale
-
-         attn = self.softmax(dots)
-         attn = self.attn_drop(attn)
-
-         self.save_attn(attn)
-         attn.register_hook(self.save_attn_gradients)
-
-         out = self.matmul2([attn, v])
-         out = rearrange(out, 'b h n d -> b n (h d)')
-
-         out = self.proj(out)
-         out = self.proj_drop(out)
-         return out
-
-     def relprop(self, cam, **kwargs):
-         cam = self.proj_drop.relprop(cam, **kwargs)
-         cam = self.proj.relprop(cam, **kwargs)
-         cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
-
-         # attn = A*V
-         (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
-         cam1 /= 2
-         cam_v /= 2
-
-         self.save_v_cam(cam_v)
-         self.save_attn_cam(cam1)
-
-         cam1 = self.attn_drop.relprop(cam1, **kwargs)
-         cam1 = self.softmax.relprop(cam1, **kwargs)
-
-         # A = Q*K^T
-         (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
-         cam_q /= 2
-         cam_k /= 2
-
-         cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
-
-         return self.qkv.relprop(cam_qkv, **kwargs)
-
-
- class Block(nn.Module):
-
-     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
-         super().__init__()
-         self.norm1 = LayerNorm(dim, eps=1e-6)
-         self.attn = Attention(
-             dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
-         self.norm2 = LayerNorm(dim, eps=1e-6)
-         mlp_hidden_dim = int(dim * mlp_ratio)
-         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
-
-         self.add1 = Add()
-         self.add2 = Add()
-         self.clone1 = Clone()
-         self.clone2 = Clone()
-
-     def forward(self, x):
-         x1, x2 = self.clone1(x, 2)
-         x = self.add1([x1, self.attn(self.norm1(x2))])
-         x1, x2 = self.clone2(x, 2)
-         x = self.add2([x1, self.mlp(self.norm2(x2))])
-         return x
-
-     def relprop(self, cam, **kwargs):
-         (cam1, cam2) = self.add2.relprop(cam, **kwargs)
-         cam2 = self.mlp.relprop(cam2, **kwargs)
-         cam2 = self.norm2.relprop(cam2, **kwargs)
-         cam = self.clone2.relprop((cam1, cam2), **kwargs)
-
-         (cam1, cam2) = self.add1.relprop(cam, **kwargs)
-         cam2 = self.attn.relprop(cam2, **kwargs)
-         cam2 = self.norm1.relprop(cam2, **kwargs)
-         cam = self.clone1.relprop((cam1, cam2), **kwargs)
-         return cam
-
-
- class PatchEmbed(nn.Module):
-     """ Image to Patch Embedding
-     """
-     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
-         super().__init__()
-         img_size = to_2tuple(img_size)
-         patch_size = to_2tuple(patch_size)
-         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
-         self.img_size = img_size
-         self.patch_size = patch_size
-         self.num_patches = num_patches
-
-         self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-
-     def forward(self, x):
-         B, C, H, W = x.shape
-         # FIXME look at relaxing size constraints
-         assert H == self.img_size[0] and W == self.img_size[1], \
-             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-         x = self.proj(x).flatten(2).transpose(1, 2)
-         return x
-
-     def relprop(self, cam, **kwargs):
-         cam = cam.transpose(1, 2)
-         cam = cam.reshape(cam.shape[0], cam.shape[1],
-                           (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
-         return self.proj.relprop(cam, **kwargs)
-
-
- class VisionTransformer(nn.Module):
-     """ Vision Transformer with support for patch or hybrid CNN input stage
-     """
-     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                  num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
-         super().__init__()
-         self.num_classes = num_classes
-         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-         self.patch_embed = PatchEmbed(
-             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
-         num_patches = self.patch_embed.num_patches
-
-         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-
-         self.blocks = nn.ModuleList([
-             Block(
-                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
-                 drop=drop_rate, attn_drop=attn_drop_rate)
-             for i in range(depth)])
-
-         self.norm = LayerNorm(embed_dim)
-         if mlp_head:
-             # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
-             self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
-         else:
-             # with a single Linear layer as head, the param count within rounding of paper
-             self.head = Linear(embed_dim, num_classes)
-
-         # FIXME not quite sure what the proper weight init is supposed to be,
-         # normal / trunc normal w/ std == .02 similar to other Bert like transformers
-         trunc_normal_(self.pos_embed, std=.02)  # embeddings same as weights?
-         trunc_normal_(self.cls_token, std=.02)
-         self.apply(self._init_weights)
-
-         self.pool = IndexSelect()
-         self.add = Add()
-
-         self.inp_grad = None
-
-     def save_inp_grad(self, grad):
-         self.inp_grad = grad
-
-     def get_inp_grad(self):
-         return self.inp_grad
-
-
-     def _init_weights(self, m):
-         if isinstance(m, nn.Linear):
-             trunc_normal_(m.weight, std=.02)
-             if isinstance(m, nn.Linear) and m.bias is not None:
-                 nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
-             nn.init.constant_(m.bias, 0)
-             nn.init.constant_(m.weight, 1.0)
-
-     @property
-     def no_weight_decay(self):
-         return {'pos_embed', 'cls_token'}
-
-     def forward(self, x):
-         B = x.shape[0]
-         x = self.patch_embed(x)
-
-         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-         x = torch.cat((cls_tokens, x), dim=1)
-         x = self.add([x, self.pos_embed])
-
-         x.register_hook(self.save_inp_grad)
-
-         for blk in self.blocks:
-             x = blk(x)
-
-         x = self.norm(x)
-         x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
-         x = x.squeeze(1)
-         x = self.head(x)
-         return x
-
-     def relprop(self, cam=None, method="grad", is_ablation=False, start_layer=0, **kwargs):
-         # print(kwargs)
-         # print("conservation 1", cam.sum())
-         cam = self.head.relprop(cam, **kwargs)
-         cam = cam.unsqueeze(1)
-         cam = self.pool.relprop(cam, **kwargs)
-         cam = self.norm.relprop(cam, **kwargs)
-         for blk in reversed(self.blocks):
-             cam = blk.relprop(cam, **kwargs)
-
-         # print("conservation 2", cam.sum())
-         # print("min", cam.min())
-
-         if method == "full":
-             (cam, _) = self.add.relprop(cam, **kwargs)
-             cam = cam[:, 1:]
-             cam = self.patch_embed.relprop(cam, **kwargs)
-             # sum on channels
-             cam = cam.sum(dim=1)
-             return cam
-
-         elif method == "rollout":
-             # cam rollout
-             attn_cams = []
-             for blk in self.blocks:
-                 attn_heads = blk.attn.get_attn_cam().clamp(min=0)
-                 avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
-                 attn_cams.append(avg_heads)
-             cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
-             cam = cam[:, 0, 1:]
-             return cam
-
-         elif method == "grad":
-             cams = []
-             for blk in self.blocks:
-                 grad = blk.attn.get_attn_gradients()
-                 cam = blk.attn.get_attn_cam()
-                 cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
-                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
-                 cam = grad * cam
-                 cam = cam.clamp(min=0).mean(dim=0)
-                 cams.append(cam.unsqueeze(0))
-             rollout = compute_rollout_attention(cams, start_layer=start_layer)
-             cam = rollout[:, 0, 1:]
-             return cam
-
-         elif method == "last_layer":
-             cam = self.blocks[-1].attn.get_attn_cam()
-             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
-             if is_ablation:
-                 grad = self.blocks[-1].attn.get_attn_gradients()
-                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
-                 cam = grad * cam
-             cam = cam.clamp(min=0).mean(dim=0)
-             cam = cam[0, 1:]
-             return cam
-
-         elif method == "last_layer_attn":
-             cam = self.blocks[-1].attn.get_attn()
-             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
-             cam = cam.clamp(min=0).mean(dim=0)
-             cam = cam[0, 1:]
-             return cam
-
-         elif method == "second_layer":
-             cam = self.blocks[1].attn.get_attn_cam()
-             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
-             if is_ablation:
-                 grad = self.blocks[1].attn.get_attn_gradients()
-                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
-                 cam = grad * cam
-             cam = cam.clamp(min=0).mean(dim=0)
-             cam = cam[0, 1:]
-             return cam
-
-
- def _conv_filter(state_dict, patch_size=16):
-     """ convert patch embedding weight from manual patchify + linear proj to conv"""
-     out_dict = {}
-     for k, v in state_dict.items():
-         if 'patch_embed.proj.weight' in k:
-             v = v.reshape((v.shape[0], 3, patch_size, patch_size))
-         out_dict[k] = v
-     return out_dict
-
-
- def vit_base_patch16_224(pretrained=False, **kwargs):
-     model = VisionTransformer(
-         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
-     model.default_cfg = default_cfgs['vit_base_patch16_224']
-     if pretrained:
-         load_pretrained(
-             model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
-     return model
-
- def vit_large_patch16_224(pretrained=False, **kwargs):
-     model = VisionTransformer(
-         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
-     model.default_cfg = default_cfgs['vit_large_patch16_224']
-     if pretrained:
-         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
-     return model
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_LRP.cpython-310.pyc DELETED
Binary file (14.4 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_explanation_generator.cpython-310.pyc DELETED
Binary file (3.49 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_new.cpython-310.pyc DELETED
Binary file (9.15 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_orig_LRP.cpython-310.pyc DELETED
Binary file (13.9 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/helpers.cpython-310.pyc DELETED
Binary file (7.28 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/layer_helpers.cpython-310.pyc DELETED
Binary file (810 Bytes)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/weight_init.cpython-310.pyc DELETED
Binary file (1.98 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/VOC.py DELETED
@@ -1,395 +0,0 @@
- import os
- import tarfile
- import torch
- import torch.utils.data as data
- import numpy as np
- import h5py
-
- from PIL import Image
- from scipy import io
- from torchvision.datasets.utils import download_url
-
- DATASET_YEAR_DICT = {
-     '2012': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
-         'filename': 'VOCtrainval_11-May-2012.tar',
-         'md5': '6cd6e144f989b92b3379bac3b3de84fd',
-         'base_dir': 'VOCdevkit/VOC2012'
-     },
-     '2011': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
-         'filename': 'VOCtrainval_25-May-2011.tar',
-         'md5': '6c3384ef61512963050cb5d687e5bf1e',
-         'base_dir': 'TrainVal/VOCdevkit/VOC2011'
-     },
-     '2010': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
-         'filename': 'VOCtrainval_03-May-2010.tar',
-         'md5': 'da459979d0c395079b5c75ee67908abb',
-         'base_dir': 'VOCdevkit/VOC2010'
-     },
-     '2009': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
-         'filename': 'VOCtrainval_11-May-2009.tar',
-         'md5': '59065e4b188729180974ef6572f6a212',
-         'base_dir': 'VOCdevkit/VOC2009'
-     },
-     '2008': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
-         'filename': 'VOCtrainval_11-May-2012.tar',
-         'md5': '2629fa636546599198acfcfbfcf1904a',
-         'base_dir': 'VOCdevkit/VOC2008'
-     },
-     '2007': {
-         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
-         'filename': 'VOCtrainval_06-Nov-2007.tar',
-         'md5': 'c52e279531787c972589f7e41ab4ae64',
-         'base_dir': 'VOCdevkit/VOC2007'
-     }
- }
-
-
- class VOCSegmentation(data.Dataset):
-     """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
-
-     Args:
-         root (string): Root directory of the VOC Dataset.
-         year (string, optional): The dataset year, supports years 2007 to 2012.
-         image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
-         download (bool, optional): If true, downloads the dataset from the internet and
-             puts it in root directory. If dataset is already downloaded, it is not
-             downloaded again.
-         transform (callable, optional): A function/transform that takes in an PIL image
-             and returns a transformed version. E.g, ``transforms.RandomCrop``
-         target_transform (callable, optional): A function/transform that takes in the
-             target and transforms it.
-     """
-
-     CLASSES = 20
-     # CLASSES_NAMES = [
-     #     "background", 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
-     #     'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',
-     #     'motorcycle', 'person', 'pot', 'sheep', 'sofa', 'train',
-     #     'monitor'
-     #     # 'ambigious'
-     # ]
-     CLASSES_NAMES = [
-         "background", 'plane', 'bike', 'bird', 'boat', 'bottle',
-         'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',
-         'motorcycle', 'person', 'pot', 'sheep', 'sofa', 'train',
-         'monitor'
-         # 'ambigious'
-     ]
-
-     def __init__(
-         self,
-         root,
-         year='2012',
-         image_set='train',
-         download=False,
-         transform=None,
-         target_transform=None,
-         binary_class=False
-     ):
-         self.root = os.path.expanduser(root)
-         self.binary_class = binary_class
-         self.year = year
-         self.url = DATASET_YEAR_DICT[year]['url']
-         self.filename = DATASET_YEAR_DICT[year]['filename']
-         self.md5 = DATASET_YEAR_DICT[year]['md5']
-         self.transform = transform
-         self.target_transform = target_transform
-         self.image_set = image_set
-         base_dir = DATASET_YEAR_DICT[year]['base_dir']
-         voc_root = os.path.join(self.root, base_dir)
-         image_dir = os.path.join(voc_root, 'JPEGImages')
-         mask_dir = os.path.join(voc_root, 'SegmentationClass')
-
-         if download:
-             download_extract(self.url, self.root, self.filename, self.md5)
-
-         if not os.path.isdir(voc_root):
-             raise RuntimeError('Dataset not found or corrupted.' +
-                                ' You can use download=True to download it')
-
-         splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
-
-         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
-
-         if not os.path.exists(split_f):
-             raise ValueError(
-                 'Wrong image_set entered! Please use image_set="train" '
-                 'or image_set="trainval" or image_set="val"')
-
-         with open(os.path.join(split_f), "r") as f:
-             file_names = [x.strip() for x in f.readlines()]
-
-         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
-         self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
-         assert (len(self.images) == len(self.masks))
-
-     def __getitem__(self, index):
-         """
-         Args:
-             index (int): Index
-
-         Returns:
-             tuple: (image, target) where target is the image segmentation.
-         """
-         img = Image.open(self.images[index]).convert('RGB')
-         target = Image.open(self.masks[index])
-
-         if self.transform is not None:
-             img = self.transform(img)
-
-         if self.target_transform is not None:
-             target = np.array(self.target_transform(target)).astype('int32')
-             target[target == 255] = -1
-             target = torch.from_numpy(target).long()
-
-         # # Convert target to (2, height, width)
-         # target = torch.stack([target, 1 - target], dim=0)
-         # Get a list of the classes that are present in the image
-         visible_classes = np.unique(target)
-         # Convert these to class names
-         present_classes = [self.CLASSES_NAMES[i] for i in visible_classes if i != -1]
-
-         if self.binary_class:
-             # Take all classes that aren't zero or -1 and mkae them 1
-             target[target >= 1] = 1
-
-         return img, target, present_classes
-
-     @staticmethod
-     def _mask_transform(mask):
-         target = np.array(mask).astype('int32')
-         target[target == 255] = -1
-         return torch.from_numpy(target).long()
-
-     def __len__(self):
-         return len(self.images)
-
-     @property
-     def pred_offset(self):
-         return 0
-
-
- class VOCClassification(data.Dataset):
-     """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
-
-     Args:
-         root (string): Root directory of the VOC Dataset.
-         year (string, optional): The dataset year, supports years 2007 to 2012.
-         image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
-         download (bool, optional): If true, downloads the dataset from the internet and
-             puts it in root directory. If dataset is already downloaded, it is not
-             downloaded again.
-         transform (callable, optional): A function/transform that takes in an PIL image
-             and returns a transformed version. E.g, ``transforms.RandomCrop``
-     """
-     CLASSES = 20
-
-     def __init__(self,
-                  root,
-                  year='2012',
-                  image_set='train',
-                  download=False,
-                  transform=None):
-         self.root = os.path.expanduser(root)
-         self.year = year
-         self.url = DATASET_YEAR_DICT[year]['url']
-         self.filename = DATASET_YEAR_DICT[year]['filename']
-         self.md5 = DATASET_YEAR_DICT[year]['md5']
-         self.transform = transform
-         self.image_set = image_set
-         base_dir = DATASET_YEAR_DICT[year]['base_dir']
-         voc_root = os.path.join(self.root, base_dir)
-         image_dir = os.path.join(voc_root, 'JPEGImages')
-         mask_dir = os.path.join(voc_root, 'SegmentationClass')
-
-         if download:
-             download_extract(self.url, self.root, self.filename, self.md5)
-
-         if not os.path.isdir(voc_root):
-             raise RuntimeError('Dataset not found or corrupted.' +
-                                ' You can use download=True to download it')
-
-         splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
-
-         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
-
-         if not os.path.exists(split_f):
-             raise ValueError(
-                 'Wrong image_set entered! Please use image_set="train" '
-                 'or image_set="trainval" or image_set="val"')
-
-         with open(os.path.join(split_f), "r") as f:
-             file_names = [x.strip() for x in f.readlines()]
-
-         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
-         self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
-         assert (len(self.images) == len(self.masks))
-
-     def __getitem__(self, index):
-         """
-         Args:
-             index (int): Index
-
-         Returns:
-             tuple: (image, target) where target is the image segmentation.
-         """
-         img = Image.open(self.images[index]).convert('RGB')
-         target = Image.open(self.masks[index])
-
-         # if self.transform is not None:
-         #     img = self.transform(img)
-         if self.transform is not None:
-             img, target = self.transform(img, target)
-
-         visible_classes = np.unique(target)
-         labels = torch.zeros(self.CLASSES)
-         for id in visible_classes:
-             if id not in (0, 255):
-                 labels[id - 1].fill_(1)
-
-         return img, labels
-
-     def __len__(self):
-         return len(self.images)
-
-
- class VOCSBDClassification(data.Dataset):
-     """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
-
-     Args:
-         root (string): Root directory of the VOC Dataset.
-         year (string, optional): The dataset year, supports years 2007 to 2012.
-         image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
-         download (bool, optional): If true, downloads the dataset from the internet and
-             puts it in root directory. If dataset is already downloaded, it is not
-             downloaded again.
-         transform (callable, optional): A function/transform that takes in an PIL image
-             and returns a transformed version. E.g, ``transforms.RandomCrop``
-     """
-     CLASSES = 20
-
-     def __init__(self,
-                  root,
-                  sbd_root,
-                  year='2012',
-                  image_set='train',
-                  download=False,
-                  transform=None):
-         self.root = os.path.expanduser(root)
-         self.sbd_root = os.path.expanduser(sbd_root)
-         self.year = year
-         self.url = DATASET_YEAR_DICT[year]['url']
-         self.filename = DATASET_YEAR_DICT[year]['filename']
-         self.md5 = DATASET_YEAR_DICT[year]['md5']
-         self.transform = transform
-         self.image_set = image_set
-         base_dir = DATASET_YEAR_DICT[year]['base_dir']
-         voc_root = os.path.join(self.root, base_dir)
-         image_dir = os.path.join(voc_root, 'JPEGImages')
-         mask_dir = os.path.join(voc_root, 'SegmentationClass')
-         sbd_image_dir = os.path.join(sbd_root, 'img')
-         sbd_mask_dir = os.path.join(sbd_root, 'cls')
-
-         if download:
-             download_extract(self.url, self.root, self.filename, self.md5)
-
-         if not os.path.isdir(voc_root):
-             raise RuntimeError('Dataset not found or corrupted.' +
-                                ' You can use download=True to download it')
-
-         splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
-
-         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
-         sbd_split = os.path.join(sbd_root, 'train.txt')
-
-         if not os.path.exists(split_f):
-             raise ValueError(
-                 'Wrong image_set entered! Please use image_set="train" '
-                 'or image_set="trainval" or image_set="val"')
-
-         with open(os.path.join(split_f), "r") as f:
-             voc_file_names = [x.strip() for x in f.readlines()]
-
-         with open(os.path.join(sbd_split), "r") as f:
-             sbd_file_names = [x.strip() for x in f.readlines()]
-
-         self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
-         self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
-         self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
-         self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
-         assert (len(self.images) == len(self.masks))
-
-     def __getitem__(self, index):
-         """
-         Args:
-             index (int): Index
-
-         Returns:
-             tuple: (image, target) where target is the image segmentation.
-         """
-         img = Image.open(self.images[index]).convert('RGB')
-         mask_path = self.masks[index]
-         if mask_path[-3:] == 'mat':
-             target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation
-             target = Image.fromarray(target, mode='P')
-         else:
-             target = Image.open(self.masks[index])
-
-         if self.transform is not None:
-             img, target = self.transform(img, target)
-
-         visible_classes = np.unique(target)
-         labels = torch.zeros(self.CLASSES)
-         for id in visible_classes:
-             if id not in (0, 255):
-                 labels[id - 1].fill_(1)
-
-         return img, labels
-
-     def __len__(self):
-         return len(self.images)
-
-
- def download_extract(url, root, filename, md5):
-     download_url(url, root, filename, md5)
-     with tarfile.open(os.path.join(root, filename), "r") as tar:
-         tar.extractall(path=root)
-
-
- class VOCResults(data.Dataset):
-     CLASSES = 20
-     CLASSES_NAMES = [
-         'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
-         'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
-         'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
-         'tvmonitor', 'ambigious'
-     ]
-
-     def __init__(self, path):
-         super(VOCResults, self).__init__()
-
-         self.path = os.path.join(path, 'results.hdf5')
-         self.data = None
-
-         print('Reading dataset length...')
-         with h5py.File(self.path, 'r') as f:
-             self.data_length = len(f['/image'])
-
-     def __len__(self):
-         return self.data_length
-
-     def __getitem__(self, item):
-         if self.data is None:
-             self.data = h5py.File(self.path, 'r')
-
-         image = torch.tensor(self.data['image'][item])
-         vis = torch.tensor(self.data['vis'][item])
-         target = torch.tensor(self.data['target'][item])
-         class_pred = torch.tensor(self.data['class_pred'][item])
-
-         return image, vis, target, class_pred
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__init__.py DELETED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/Imagenet.cpython-310.pyc DELETED
Binary file (5.25 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/VOC.cpython-310.pyc DELETED
Binary file (12.1 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (220 Bytes)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/imagenet.cpython-310.pyc DELETED
Binary file (5.37 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet.py DELETED
@@ -1,200 +0,0 @@
- import os
- import torch
- import torch.utils.data as data
- import numpy as np
- import cv2
-
- from torchvision.datasets import ImageNet
-
- from PIL import Image, ImageFilter
- import h5py
- from glob import glob
-
-
- class ImageNet_blur(ImageNet):
-     def __getitem__(self, index):
-         """
-         Args:
-             index (int): Index
-
-         Returns:
-             tuple: (sample, target) where target is class_index of the target class.
-         """
-         path, target = self.samples[index]
-         sample = self.loader(path)
-
-         gauss_blur = ImageFilter.GaussianBlur(11)
-         median_blur = ImageFilter.MedianFilter(11)
-
-         blurred_img1 = sample.filter(gauss_blur)
-         blurred_img2 = sample.filter(median_blur)
-         blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
-
-         if self.transform is not None:
-             sample = self.transform(sample)
-             blurred_img = self.transform(blurred_img)
-         if self.target_transform is not None:
-             target = self.target_transform(target)
-
-         return (sample, blurred_img), target
-
-
- class Imagenet_Segmentation(data.Dataset):
-     CLASSES = 2
-
-     def __init__(self,
-                  path,
-                  transform=None,
-                  target_transform=None):
-         self.path = path
-         self.transform = transform
-         self.target_transform = target_transform
-         # self.h5py = h5py.File(path, 'r+')
-         self.h5py = None
-         with h5py.File(path, 'r') as tmp:
-             self.data_length = len(tmp['/value/img'])
-
-     def __getitem__(self, index):
-
-         if self.h5py is None:
-             self.h5py = h5py.File(self.path, 'r')
-
-         img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
-         target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
-
-         img = Image.fromarray(img).convert('RGB')
-         target = Image.fromarray(target)
-
-         if self.transform is not None:
-             img = self.transform(img)
-
-         if self.target_transform is not None:
-             target = np.array(self.target_transform(target)).astype('int32')
-             target = torch.from_numpy(target).long()
-
-         return img, target
-
-     def __len__(self):
-         # return len(self.h5py['/value/img'])
-         return self.data_length
-
-
- class Imagenet_Segmentation_Blur(data.Dataset):
-     CLASSES = 2
-
-     def __init__(self,
-                  path,
-                  transform=None,
-                  target_transform=None):
-         self.path = path
-         self.transform = transform
-         self.target_transform = target_transform
-         # self.h5py = h5py.File(path, 'r+')
-         self.h5py = None
-         tmp = h5py.File(path, 'r')
-         self.data_length = len(tmp['/value/img'])
-         tmp.close()
-         del tmp
-
-     def __getitem__(self, index):
-
-         if self.h5py is None:
-             self.h5py = h5py.File(self.path, 'r')
-
-         img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
-         target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
-
-         img = Image.fromarray(img).convert('RGB')
-         target = Image.fromarray(target)
-
-         gauss_blur = ImageFilter.GaussianBlur(11)
-         median_blur = ImageFilter.MedianFilter(11)
-
-         blurred_img1 = img.filter(gauss_blur)
-         blurred_img2 = img.filter(median_blur)
-         blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
-
-         # blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5)
-         # blurred_img2 = np.float32(cv2.medianBlur(img, 11))
-         # blurred_img = (blurred_img1 + blurred_img2) / 2
-
-         if self.transform is not None:
-             img = self.transform(img)
-             blurred_img = self.transform(blurred_img)
-
-         if self.target_transform is not None:
-             target = np.array(self.target_transform(target)).astype('int32')
-             target = torch.from_numpy(target).long()
-
-         return (img, blurred_img), target
-
-     def __len__(self):
-         # return len(self.h5py['/value/img'])
-         return self.data_length
-
-
- class Imagenet_Segmentation_eval_dir(data.Dataset):
-     CLASSES = 2
-
-     def __init__(self,
-                  path,
-                  eval_path,
-                  transform=None,
-                  target_transform=None):
-         self.transform = transform
-         self.target_transform = target_transform
-         self.h5py = h5py.File(path, 'r+')
-
-         # 500 each file
-         self.results = glob(os.path.join(eval_path, '*.npy'))
-
-     def __getitem__(self, index):
-
-         img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
-         target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
-         res = np.load(self.results[index])
-
-         img = Image.fromarray(img).convert('RGB')
-         target = Image.fromarray(target)
-
-         if self.transform is not None:
-             img = self.transform(img)
-
-         if self.target_transform is not None:
-             target = np.array(self.target_transform(target)).astype('int32')
-             target = torch.from_numpy(target).long()
-
-         return img, target
-
-     def __len__(self):
-         return len(self.h5py['/value/img'])
-
-
- if __name__ == '__main__':
-     import torchvision.transforms as transforms
-     from tqdm import tqdm
-     from imageio import imsave
-     import scipy.io as sio
-
-     # meta = sio.loadmat('/home/shirgur/ext/Data/Datasets/temp/ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)['synsets']
-
-     # Data
-     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                      std=[0.229, 0.224, 0.225])
-     test_img_trans = transforms.Compose([
-         transforms.Resize((224, 224)),
-         transforms.ToTensor(),
-         normalize,
-     ])
-     test_lbl_trans = transforms.Compose([
-         transforms.Resize((224, 224), Image.NEAREST),
-     ])
-
-     ds = Imagenet_Segmentation('/home/shirgur/ext/Data/Datasets/imagenet-seg/other/gtsegs_ijcv.mat',
-                                transform=test_img_trans, target_transform=test_lbl_trans)
-
-     for i, (img, tgt) in enumerate(tqdm(ds)):
-         tgt = (tgt.numpy() * 255).astype(np.uint8)
-         imsave('/home/shirgur/ext/Code/C2S/run/imagenet/gt/{}.png'.format(i), tgt)
-
-     print('here')
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet_utils.py DELETED
@@ -1,1002 +0,0 @@
1
- CLS2IDX = {
2
- 0: 'tench, Tinca tinca',
3
- 1: 'goldfish, Carassius auratus',
4
- 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
- 3: 'tiger shark, Galeocerdo cuvieri',
6
- 4: 'hammerhead, hammerhead shark',
7
- 5: 'electric ray, crampfish, numbfish, torpedo',
8
- 6: 'stingray',
9
- 7: 'cock',
10
- 8: 'hen',
11
- 9: 'ostrich, Struthio camelus',
12
- 10: 'brambling, Fringilla montifringilla',
13
- 11: 'goldfinch, Carduelis carduelis',
14
- 12: 'house finch, linnet, Carpodacus mexicanus',
15
- 13: 'junco, snowbird',
16
- 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
- 15: 'robin, American robin, Turdus migratorius',
18
- 16: 'bulbul',
19
- 17: 'jay',
20
- 18: 'magpie',
21
- 19: 'chickadee',
22
- 20: 'water ouzel, dipper',
23
- 21: 'kite',
24
- 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
- 23: 'vulture',
26
- 24: 'great grey owl, great gray owl, Strix nebulosa',
27
- 25: 'European fire salamander, Salamandra salamandra',
28
- 26: 'common newt, Triturus vulgaris',
29
- 27: 'eft',
30
- 28: 'spotted salamander, Ambystoma maculatum',
31
- 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
- 30: 'bullfrog, Rana catesbeiana',
33
- 31: 'tree frog, tree-frog',
34
- 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
- 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
- 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
- 35: 'mud turtle',
38
- 36: 'terrapin',
39
- 37: 'box turtle, box tortoise',
40
- 38: 'banded gecko',
41
- 39: 'common iguana, iguana, Iguana iguana',
42
- 40: 'American chameleon, anole, Anolis carolinensis',
43
- 41: 'whiptail, whiptail lizard',
44
- 42: 'agama',
45
- 43: 'frilled lizard, Chlamydosaurus kingi',
46
- 44: 'alligator lizard',
47
- 45: 'Gila monster, Heloderma suspectum',
48
- 46: 'green lizard, Lacerta viridis',
49
- 47: 'African chameleon, Chamaeleo chamaeleon',
50
- 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
- 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
- 50: 'American alligator, Alligator mississipiensis',
53
- 51: 'triceratops',
54
- 52: 'thunder snake, worm snake, Carphophis amoenus',
55
- 53: 'ringneck snake, ring-necked snake, ring snake',
56
- 54: 'hognose snake, puff adder, sand viper',
57
- 55: 'green snake, grass snake',
58
- 56: 'king snake, kingsnake',
59
- 57: 'garter snake, grass snake',
60
- 58: 'water snake',
61
- 59: 'vine snake',
62
- 60: 'night snake, Hypsiglena torquata',
63
- 61: 'boa constrictor, Constrictor constrictor',
64
- 62: 'rock python, rock snake, Python sebae',
65
- 63: 'Indian cobra, Naja naja',
66
- 64: 'green mamba',
67
- 65: 'sea snake',
68
- 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
- 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
- 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
- 69: 'trilobite',
72
- 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
- 71: 'scorpion',
74
- 72: 'black and gold garden spider, Argiope aurantia',
75
- 73: 'barn spider, Araneus cavaticus',
76
- 74: 'garden spider, Aranea diademata',
77
- 75: 'black widow, Latrodectus mactans',
78
- 76: 'tarantula',
79
- 77: 'wolf spider, hunting spider',
80
- 78: 'tick',
81
- 79: 'centipede',
82
- 80: 'black grouse',
83
- 81: 'ptarmigan',
84
- 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
- 83: 'prairie chicken, prairie grouse, prairie fowl',
86
- 84: 'peacock',
87
- 85: 'quail',
88
- 86: 'partridge',
89
- 87: 'African grey, African gray, Psittacus erithacus',
90
- 88: 'macaw',
91
- 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
- 90: 'lorikeet',
93
- 91: 'coucal',
94
- 92: 'bee eater',
95
- 93: 'hornbill',
96
- 94: 'hummingbird',
97
- 95: 'jacamar',
98
- 96: 'toucan',
99
- 97: 'drake',
100
- 98: 'red-breasted merganser, Mergus serrator',
101
- 99: 'goose',
102
- 100: 'black swan, Cygnus atratus',
103
- 101: 'tusker',
104
- 102: 'echidna, spiny anteater, anteater',
105
- 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
- 104: 'wallaby, brush kangaroo',
107
- 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
- 106: 'wombat',
109
- 107: 'jellyfish',
110
- 108: 'sea anemone, anemone',
111
- 109: 'brain coral',
112
- 110: 'flatworm, platyhelminth',
113
- 111: 'nematode, nematode worm, roundworm',
114
- 112: 'conch',
115
- 113: 'snail',
116
- 114: 'slug',
117
- 115: 'sea slug, nudibranch',
118
- 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
- 117: 'chambered nautilus, pearly nautilus, nautilus',
120
- 118: 'Dungeness crab, Cancer magister',
121
- 119: 'rock crab, Cancer irroratus',
122
- 120: 'fiddler crab',
123
- 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
- 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
- 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
- 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
- 125: 'hermit crab',
128
- 126: 'isopod',
129
- 127: 'white stork, Ciconia ciconia',
130
- 128: 'black stork, Ciconia nigra',
131
- 129: 'spoonbill',
132
- 130: 'flamingo',
133
- 131: 'little blue heron, Egretta caerulea',
134
- 132: 'American egret, great white heron, Egretta albus',
135
- 133: 'bittern',
136
- 134: 'crane',
137
- 135: 'limpkin, Aramus pictus',
138
- 136: 'European gallinule, Porphyrio porphyrio',
139
- 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
- 138: 'bustard',
141
- 139: 'ruddy turnstone, Arenaria interpres',
142
- 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
- 141: 'redshank, Tringa totanus',
144
- 142: 'dowitcher',
145
- 143: 'oystercatcher, oyster catcher',
146
- 144: 'pelican',
147
- 145: 'king penguin, Aptenodytes patagonica',
148
- 146: 'albatross, mollymawk',
149
- 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
- 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
- 149: 'dugong, Dugong dugon',
152
- 150: 'sea lion',
153
- 151: 'Chihuahua',
154
- 152: 'Japanese spaniel',
155
- 153: 'Maltese dog, Maltese terrier, Maltese',
156
- 154: 'Pekinese, Pekingese, Peke',
157
- 155: 'Shih-Tzu',
158
- 156: 'Blenheim spaniel',
159
- 157: 'papillon',
160
- 158: 'toy terrier',
161
- 159: 'Rhodesian ridgeback',
162
- 160: 'Afghan hound, Afghan',
163
- 161: 'basset, basset hound',
164
- 162: 'beagle',
165
- 163: 'bloodhound, sleuthhound',
166
- 164: 'bluetick',
167
- 165: 'black-and-tan coonhound',
168
- 166: 'Walker hound, Walker foxhound',
169
- 167: 'English foxhound',
170
- 168: 'redbone',
171
- 169: 'borzoi, Russian wolfhound',
172
- 170: 'Irish wolfhound',
173
- 171: 'Italian greyhound',
174
- 172: 'whippet',
175
- 173: 'Ibizan hound, Ibizan Podenco',
176
- 174: 'Norwegian elkhound, elkhound',
177
- 175: 'otterhound, otter hound',
178
- 176: 'Saluki, gazelle hound',
179
- 177: 'Scottish deerhound, deerhound',
180
- 178: 'Weimaraner',
181
- 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
- 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
- 181: 'Bedlington terrier',
184
- 182: 'Border terrier',
185
- 183: 'Kerry blue terrier',
186
- 184: 'Irish terrier',
187
- 185: 'Norfolk terrier',
188
- 186: 'Norwich terrier',
189
- 187: 'Yorkshire terrier',
190
- 188: 'wire-haired fox terrier',
191
- 189: 'Lakeland terrier',
192
- 190: 'Sealyham terrier, Sealyham',
193
- 191: 'Airedale, Airedale terrier',
194
- 192: 'cairn, cairn terrier',
195
- 193: 'Australian terrier',
196
- 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
- 195: 'Boston bull, Boston terrier',
198
- 196: 'miniature schnauzer',
199
- 197: 'giant schnauzer',
200
- 198: 'standard schnauzer',
201
- 199: 'Scotch terrier, Scottish terrier, Scottie',
202
- 200: 'Tibetan terrier, chrysanthemum dog',
203
- 201: 'silky terrier, Sydney silky',
204
- 202: 'soft-coated wheaten terrier',
205
- 203: 'West Highland white terrier',
206
- 204: 'Lhasa, Lhasa apso',
207
- 205: 'flat-coated retriever',
208
- 206: 'curly-coated retriever',
209
- 207: 'golden retriever',
210
- 208: 'Labrador retriever',
211
- 209: 'Chesapeake Bay retriever',
212
- 210: 'German short-haired pointer',
213
- 211: 'vizsla, Hungarian pointer',
214
- 212: 'English setter',
215
- 213: 'Irish setter, red setter',
216
- 214: 'Gordon setter',
217
- 215: 'Brittany spaniel',
218
- 216: 'clumber, clumber spaniel',
219
- 217: 'English springer, English springer spaniel',
220
- 218: 'Welsh springer spaniel',
221
- 219: 'cocker spaniel, English cocker spaniel, cocker',
222
- 220: 'Sussex spaniel',
223
- 221: 'Irish water spaniel',
224
- 222: 'kuvasz',
225
- 223: 'schipperke',
226
- 224: 'groenendael',
227
- 225: 'malinois',
228
- 226: 'briard',
229
- 227: 'kelpie',
230
- 228: 'komondor',
231
- 229: 'Old English sheepdog, bobtail',
232
- 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
- 231: 'collie',
234
- 232: 'Border collie',
235
- 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
- 234: 'Rottweiler',
237
- 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
- 236: 'Doberman, Doberman pinscher',
239
- 237: 'miniature pinscher',
240
- 238: 'Greater Swiss Mountain dog',
241
- 239: 'Bernese mountain dog',
242
- 240: 'Appenzeller',
243
- 241: 'EntleBucher',
244
- 242: 'boxer',
245
- 243: 'bull mastiff',
246
- 244: 'Tibetan mastiff',
247
- 245: 'French bulldog',
248
- 246: 'Great Dane',
249
- 247: 'Saint Bernard, St Bernard',
250
- 248: 'Eskimo dog, husky',
251
- 249: 'malamute, malemute, Alaskan malamute',
252
- 250: 'Siberian husky',
253
- 251: 'dalmatian, coach dog, carriage dog',
254
- 252: 'affenpinscher, monkey pinscher, monkey dog',
255
- 253: 'basenji',
256
- 254: 'pug, pug-dog',
257
- 255: 'Leonberg',
258
- 256: 'Newfoundland, Newfoundland dog',
259
- 257: 'Great Pyrenees',
260
- 258: 'Samoyed, Samoyede',
261
- 259: 'Pomeranian',
262
- 260: 'chow, chow chow',
263
- 261: 'keeshond',
264
- 262: 'Brabancon griffon',
265
- 263: 'Pembroke, Pembroke Welsh corgi',
266
- 264: 'Cardigan, Cardigan Welsh corgi',
267
- 265: 'toy poodle',
268
- 266: 'miniature poodle',
269
- 267: 'standard poodle',
270
- 268: 'Mexican hairless',
271
- 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
- 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
- 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
- 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
- 273: 'dingo, warrigal, warragal, Canis dingo',
276
- 274: 'dhole, Cuon alpinus',
277
- 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
- 276: 'hyena, hyaena',
279
- 277: 'red fox, Vulpes vulpes',
280
- 278: 'kit fox, Vulpes macrotis',
281
- 279: 'Arctic fox, white fox, Alopex lagopus',
282
- 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
- 281: 'tabby, tabby cat',
284
- 282: 'tiger cat',
285
- 283: 'Persian cat',
286
- 284: 'Siamese cat, Siamese',
287
- 285: 'Egyptian cat',
288
- 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
- 287: 'lynx, catamount',
290
- 288: 'leopard, Panthera pardus',
291
- 289: 'snow leopard, ounce, Panthera uncia',
292
- 290: 'jaguar, panther, Panthera onca, Felis onca',
293
- 291: 'lion, king of beasts, Panthera leo',
294
- 292: 'tiger, Panthera tigris',
295
- 293: 'cheetah, chetah, Acinonyx jubatus',
296
- 294: 'brown bear, bruin, Ursus arctos',
297
- 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
- 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
- 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
- 298: 'mongoose',
301
- 299: 'meerkat, mierkat',
302
- 300: 'tiger beetle',
303
- 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
- 302: 'ground beetle, carabid beetle',
305
- 303: 'long-horned beetle, longicorn, longicorn beetle',
306
- 304: 'leaf beetle, chrysomelid',
307
- 305: 'dung beetle',
308
- 306: 'rhinoceros beetle',
309
- 307: 'weevil',
310
- 308: 'fly',
311
- 309: 'bee',
312
- 310: 'ant, emmet, pismire',
313
- 311: 'grasshopper, hopper',
314
- 312: 'cricket',
315
- 313: 'walking stick, walkingstick, stick insect',
316
- 314: 'cockroach, roach',
317
- 315: 'mantis, mantid',
318
- 316: 'cicada, cicala',
319
- 317: 'leafhopper',
320
- 318: 'lacewing, lacewing fly',
321
- 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
- 320: 'damselfly',
323
- 321: 'admiral',
324
- 322: 'ringlet, ringlet butterfly',
325
- 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
- 324: 'cabbage butterfly',
327
- 325: 'sulphur butterfly, sulfur butterfly',
328
- 326: 'lycaenid, lycaenid butterfly',
329
- 327: 'starfish, sea star',
330
- 328: 'sea urchin',
331
- 329: 'sea cucumber, holothurian',
332
- 330: 'wood rabbit, cottontail, cottontail rabbit',
333
- 331: 'hare',
334
- 332: 'Angora, Angora rabbit',
335
- 333: 'hamster',
336
- 334: 'porcupine, hedgehog',
337
- 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
- 336: 'marmot',
339
- 337: 'beaver',
340
- 338: 'guinea pig, Cavia cobaya',
341
- 339: 'sorrel',
342
- 340: 'zebra',
343
- 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
- 342: 'wild boar, boar, Sus scrofa',
345
- 343: 'warthog',
346
- 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
- 345: 'ox',
348
- 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
- 347: 'bison',
350
- 348: 'ram, tup',
351
- 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
- 350: 'ibex, Capra ibex',
353
- 351: 'hartebeest',
354
- 352: 'impala, Aepyceros melampus',
355
- 353: 'gazelle',
356
- 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
- 355: 'llama',
358
- 356: 'weasel',
359
- 357: 'mink',
360
- 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
- 359: 'black-footed ferret, ferret, Mustela nigripes',
362
- 360: 'otter',
363
- 361: 'skunk, polecat, wood pussy',
364
- 362: 'badger',
365
- 363: 'armadillo',
366
- 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
- 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
- 366: 'gorilla, Gorilla gorilla',
369
- 367: 'chimpanzee, chimp, Pan troglodytes',
370
- 368: 'gibbon, Hylobates lar',
371
- 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
- 370: 'guenon, guenon monkey',
373
- 371: 'patas, hussar monkey, Erythrocebus patas',
374
- 372: 'baboon',
375
- 373: 'macaque',
376
- 374: 'langur',
377
- 375: 'colobus, colobus monkey',
378
- 376: 'proboscis monkey, Nasalis larvatus',
379
- 377: 'marmoset',
380
- 378: 'capuchin, ringtail, Cebus capucinus',
381
- 379: 'howler monkey, howler',
382
- 380: 'titi, titi monkey',
383
- 381: 'spider monkey, Ateles geoffroyi',
384
- 382: 'squirrel monkey, Saimiri sciureus',
385
- 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
- 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
- 385: 'Indian elephant, Elephas maximus',
388
- 386: 'African elephant, Loxodonta africana',
389
- 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
- 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
- 389: 'barracouta, snoek',
392
- 390: 'eel',
393
- 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
- 392: 'rock beauty, Holocanthus tricolor',
395
- 393: 'anemone fish',
396
- 394: 'sturgeon',
397
- 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
- 396: 'lionfish',
399
- 397: 'puffer, pufferfish, blowfish, globefish',
400
- 398: 'abacus',
401
- 399: 'abaya',
402
- 400: "academic gown, academic robe, judge's robe",
403
- 401: 'accordion, piano accordion, squeeze box',
404
- 402: 'acoustic guitar',
405
- 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
- 404: 'airliner',
407
- 405: 'airship, dirigible',
408
- 406: 'altar',
409
- 407: 'ambulance',
410
- 408: 'amphibian, amphibious vehicle',
411
- 409: 'analog clock',
412
- 410: 'apiary, bee house',
413
- 411: 'apron',
414
- 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
- 413: 'assault rifle, assault gun',
416
- 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
- 415: 'bakery, bakeshop, bakehouse',
418
- 416: 'balance beam, beam',
419
- 417: 'balloon',
420
- 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
- 419: 'Band Aid',
422
- 420: 'banjo',
423
- 421: 'bannister, banister, balustrade, balusters, handrail',
424
- 422: 'barbell',
425
- 423: 'barber chair',
426
- 424: 'barbershop',
427
- 425: 'barn',
428
- 426: 'barometer',
429
- 427: 'barrel, cask',
430
- 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
- 429: 'baseball',
432
- 430: 'basketball',
433
- 431: 'bassinet',
434
- 432: 'bassoon',
435
- 433: 'bathing cap, swimming cap',
436
- 434: 'bath towel',
437
- 435: 'bathtub, bathing tub, bath, tub',
438
- 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
- 437: 'beacon, lighthouse, beacon light, pharos',
440
- 438: 'beaker',
441
- 439: 'bearskin, busby, shako',
442
- 440: 'beer bottle',
443
- 441: 'beer glass',
444
- 442: 'bell cote, bell cot',
445
- 443: 'bib',
446
- 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
- 445: 'bikini, two-piece',
448
- 446: 'binder, ring-binder',
449
- 447: 'binoculars, field glasses, opera glasses',
450
- 448: 'birdhouse',
451
- 449: 'boathouse',
452
- 450: 'bobsled, bobsleigh, bob',
453
- 451: 'bolo tie, bolo, bola tie, bola',
454
- 452: 'bonnet, poke bonnet',
455
- 453: 'bookcase',
456
- 454: 'bookshop, bookstore, bookstall',
457
- 455: 'bottlecap',
458
- 456: 'bow',
459
- 457: 'bow tie, bow-tie, bowtie',
460
- 458: 'brass, memorial tablet, plaque',
461
- 459: 'brassiere, bra, bandeau',
462
- 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
- 461: 'breastplate, aegis, egis',
464
- 462: 'broom',
465
- 463: 'bucket, pail',
466
- 464: 'buckle',
467
- 465: 'bulletproof vest',
468
- 466: 'bullet train, bullet',
469
- 467: 'butcher shop, meat market',
470
- 468: 'cab, hack, taxi, taxicab',
471
- 469: 'caldron, cauldron',
472
- 470: 'candle, taper, wax light',
473
- 471: 'cannon',
474
- 472: 'canoe',
475
- 473: 'can opener, tin opener',
476
- 474: 'cardigan',
477
- 475: 'car mirror',
478
- 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
- 477: "carpenter's kit, tool kit",
480
- 478: 'carton',
481
- 479: 'car wheel',
482
- 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
- 481: 'cassette',
484
- 482: 'cassette player',
485
- 483: 'castle',
486
- 484: 'catamaran',
487
- 485: 'CD player',
488
- 486: 'cello, violoncello',
489
- 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
- 488: 'chain',
491
- 489: 'chainlink fence',
492
- 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
- 491: 'chain saw, chainsaw',
494
- 492: 'chest',
495
- 493: 'chiffonier, commode',
496
- 494: 'chime, bell, gong',
497
- 495: 'china cabinet, china closet',
498
- 496: 'Christmas stocking',
499
- 497: 'church, church building',
500
- 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
- 499: 'cleaver, meat cleaver, chopper',
502
- 500: 'cliff dwelling',
503
- 501: 'cloak',
504
- 502: 'clog, geta, patten, sabot',
505
- 503: 'cocktail shaker',
506
- 504: 'coffee mug',
507
- 505: 'coffeepot',
508
- 506: 'coil, spiral, volute, whorl, helix',
509
- 507: 'combination lock',
510
- 508: 'computer keyboard, keypad',
511
- 509: 'confectionery, confectionary, candy store',
512
- 510: 'container ship, containership, container vessel',
513
- 511: 'convertible',
514
- 512: 'corkscrew, bottle screw',
515
- 513: 'cornet, horn, trumpet, trump',
516
- 514: 'cowboy boot',
517
- 515: 'cowboy hat, ten-gallon hat',
518
- 516: 'cradle',
519
- 517: 'crane',
520
- 518: 'crash helmet',
521
- 519: 'crate',
522
- 520: 'crib, cot',
523
- 521: 'Crock Pot',
524
- 522: 'croquet ball',
525
- 523: 'crutch',
526
- 524: 'cuirass',
527
- 525: 'dam, dike, dyke',
528
- 526: 'desk',
529
- 527: 'desktop computer',
530
- 528: 'dial telephone, dial phone',
531
- 529: 'diaper, nappy, napkin',
532
- 530: 'digital clock',
533
- 531: 'digital watch',
534
- 532: 'dining table, board',
535
- 533: 'dishrag, dishcloth',
536
- 534: 'dishwasher, dish washer, dishwashing machine',
537
- 535: 'disk brake, disc brake',
538
- 536: 'dock, dockage, docking facility',
539
- 537: 'dogsled, dog sled, dog sleigh',
540
- 538: 'dome',
541
- 539: 'doormat, welcome mat',
542
- 540: 'drilling platform, offshore rig',
543
- 541: 'drum, membranophone, tympan',
544
- 542: 'drumstick',
545
- 543: 'dumbbell',
546
- 544: 'Dutch oven',
547
- 545: 'electric fan, blower',
548
- 546: 'electric guitar',
549
- 547: 'electric locomotive',
550
- 548: 'entertainment center',
551
- 549: 'envelope',
552
- 550: 'espresso maker',
553
- 551: 'face powder',
554
- 552: 'feather boa, boa',
555
- 553: 'file, file cabinet, filing cabinet',
556
- 554: 'fireboat',
557
- 555: 'fire engine, fire truck',
558
- 556: 'fire screen, fireguard',
559
- 557: 'flagpole, flagstaff',
560
- 558: 'flute, transverse flute',
561
- 559: 'folding chair',
562
- 560: 'football helmet',
563
- 561: 'forklift',
564
- 562: 'fountain',
565
- 563: 'fountain pen',
566
- 564: 'four-poster',
567
- 565: 'freight car',
568
- 566: 'French horn, horn',
569
- 567: 'frying pan, frypan, skillet',
570
- 568: 'fur coat',
571
- 569: 'garbage truck, dustcart',
572
- 570: 'gasmask, respirator, gas helmet',
573
- 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
- 572: 'goblet',
575
- 573: 'go-kart',
576
- 574: 'golf ball',
577
- 575: 'golfcart, golf cart',
578
- 576: 'gondola',
579
- 577: 'gong, tam-tam',
580
- 578: 'gown',
581
- 579: 'grand piano, grand',
582
- 580: 'greenhouse, nursery, glasshouse',
583
- 581: 'grille, radiator grille',
584
- 582: 'grocery store, grocery, food market, market',
585
- 583: 'guillotine',
586
- 584: 'hair slide',
587
- 585: 'hair spray',
588
- 586: 'half track',
589
- 587: 'hammer',
590
- 588: 'hamper',
591
- 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
- 590: 'hand-held computer, hand-held microcomputer',
593
- 591: 'handkerchief, hankie, hanky, hankey',
594
- 592: 'hard disc, hard disk, fixed disk',
595
- 593: 'harmonica, mouth organ, harp, mouth harp',
596
- 594: 'harp',
597
- 595: 'harvester, reaper',
598
- 596: 'hatchet',
599
- 597: 'holster',
600
- 598: 'home theater, home theatre',
601
- 599: 'honeycomb',
602
- 600: 'hook, claw',
603
- 601: 'hoopskirt, crinoline',
604
- 602: 'horizontal bar, high bar',
605
- 603: 'horse cart, horse-cart',
606
- 604: 'hourglass',
607
- 605: 'iPod',
608
- 606: 'iron, smoothing iron',
609
- 607: "jack-o'-lantern",
610
- 608: 'jean, blue jean, denim',
611
- 609: 'jeep, landrover',
612
- 610: 'jersey, T-shirt, tee shirt',
613
- 611: 'jigsaw puzzle',
614
- 612: 'jinrikisha, ricksha, rickshaw',
615
- 613: 'joystick',
616
- 614: 'kimono',
617
- 615: 'knee pad',
618
- 616: 'knot',
619
- 617: 'lab coat, laboratory coat',
620
- 618: 'ladle',
621
- 619: 'lampshade, lamp shade',
622
- 620: 'laptop, laptop computer',
623
- 621: 'lawn mower, mower',
624
- 622: 'lens cap, lens cover',
625
- 623: 'letter opener, paper knife, paperknife',
626
- 624: 'library',
627
- 625: 'lifeboat',
628
- 626: 'lighter, light, igniter, ignitor',
629
- 627: 'limousine, limo',
630
- 628: 'liner, ocean liner',
631
- 629: 'lipstick, lip rouge',
632
- 630: 'Loafer',
633
- 631: 'lotion',
634
- 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
- 633: "loupe, jeweler's loupe",
636
- 634: 'lumbermill, sawmill',
637
- 635: 'magnetic compass',
638
- 636: 'mailbag, postbag',
639
- 637: 'mailbox, letter box',
640
- 638: 'maillot',
641
- 639: 'maillot, tank suit',
642
- 640: 'manhole cover',
643
- 641: 'maraca',
644
- 642: 'marimba, xylophone',
645
- 643: 'mask',
646
- 644: 'matchstick',
647
- 645: 'maypole',
648
- 646: 'maze, labyrinth',
649
- 647: 'measuring cup',
650
- 648: 'medicine chest, medicine cabinet',
651
- 649: 'megalith, megalithic structure',
652
- 650: 'microphone, mike',
653
- 651: 'microwave, microwave oven',
654
- 652: 'military uniform',
655
- 653: 'milk can',
656
- 654: 'minibus',
657
- 655: 'miniskirt, mini',
658
- 656: 'minivan',
659
- 657: 'missile',
660
- 658: 'mitten',
661
- 659: 'mixing bowl',
662
- 660: 'mobile home, manufactured home',
663
- 661: 'Model T',
664
- 662: 'modem',
665
- 663: 'monastery',
666
- 664: 'monitor',
667
- 665: 'moped',
668
- 666: 'mortar',
669
- 667: 'mortarboard',
670
- 668: 'mosque',
671
- 669: 'mosquito net',
672
- 670: 'motor scooter, scooter',
673
- 671: 'mountain bike, all-terrain bike, off-roader',
674
- 672: 'mountain tent',
675
- 673: 'mouse, computer mouse',
676
- 674: 'mousetrap',
677
- 675: 'moving van',
678
- 676: 'muzzle',
679
- 677: 'nail',
680
- 678: 'neck brace',
681
- 679: 'necklace',
682
- 680: 'nipple',
683
- 681: 'notebook, notebook computer',
684
- 682: 'obelisk',
685
- 683: 'oboe, hautboy, hautbois',
686
- 684: 'ocarina, sweet potato',
687
- 685: 'odometer, hodometer, mileometer, milometer',
688
- 686: 'oil filter',
689
- 687: 'organ, pipe organ',
690
- 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
- 689: 'overskirt',
692
- 690: 'oxcart',
693
- 691: 'oxygen mask',
694
- 692: 'packet',
695
- 693: 'paddle, boat paddle',
696
- 694: 'paddlewheel, paddle wheel',
697
- 695: 'padlock',
698
- 696: 'paintbrush',
699
- 697: "pajama, pyjama, pj's, jammies",
700
- 698: 'palace',
701
- 699: 'panpipe, pandean pipe, syrinx',
702
- 700: 'paper towel',
703
- 701: 'parachute, chute',
704
- 702: 'parallel bars, bars',
705
- 703: 'park bench',
706
- 704: 'parking meter',
707
- 705: 'passenger car, coach, carriage',
708
- 706: 'patio, terrace',
709
- 707: 'pay-phone, pay-station',
710
- 708: 'pedestal, plinth, footstall',
711
- 709: 'pencil box, pencil case',
712
- 710: 'pencil sharpener',
713
- 711: 'perfume, essence',
714
- 712: 'Petri dish',
715
- 713: 'photocopier',
716
- 714: 'pick, plectrum, plectron',
717
- 715: 'pickelhaube',
718
- 716: 'picket fence, paling',
719
- 717: 'pickup, pickup truck',
720
- 718: 'pier',
721
- 719: 'piggy bank, penny bank',
722
- 720: 'pill bottle',
723
- 721: 'pillow',
724
- 722: 'ping-pong ball',
725
- 723: 'pinwheel',
726
- 724: 'pirate, pirate ship',
727
- 725: 'pitcher, ewer',
728
- 726: "plane, carpenter's plane, woodworking plane",
729
- 727: 'planetarium',
730
- 728: 'plastic bag',
731
- 729: 'plate rack',
732
- 730: 'plow, plough',
733
- 731: "plunger, plumber's helper",
734
- 732: 'Polaroid camera, Polaroid Land camera',
735
- 733: 'pole',
736
- 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
- 735: 'poncho',
738
- 736: 'pool table, billiard table, snooker table',
739
- 737: 'pop bottle, soda bottle',
740
- 738: 'pot, flowerpot',
741
- 739: "potter's wheel",
742
- 740: 'power drill',
743
- 741: 'prayer rug, prayer mat',
744
- 742: 'printer',
745
- 743: 'prison, prison house',
746
- 744: 'projectile, missile',
747
- 745: 'projector',
748
- 746: 'puck, hockey puck',
749
- 747: 'punching bag, punch bag, punching ball, punchball',
750
- 748: 'purse',
751
- 749: 'quill, quill pen',
752
- 750: 'quilt, comforter, comfort, puff',
753
- 751: 'racer, race car, racing car',
754
- 752: 'racket, racquet',
755
- 753: 'radiator',
756
- 754: 'radio, wireless',
757
- 755: 'radio telescope, radio reflector',
758
- 756: 'rain barrel',
759
- 757: 'recreational vehicle, RV, R.V.',
760
- 758: 'reel',
761
- 759: 'reflex camera',
762
- 760: 'refrigerator, icebox',
763
- 761: 'remote control, remote',
764
- 762: 'restaurant, eating house, eating place, eatery',
765
- 763: 'revolver, six-gun, six-shooter',
766
- 764: 'rifle',
767
- 765: 'rocking chair, rocker',
768
- 766: 'rotisserie',
769
- 767: 'rubber eraser, rubber, pencil eraser',
770
- 768: 'rugby ball',
771
- 769: 'rule, ruler',
772
- 770: 'running shoe',
773
- 771: 'safe',
774
- 772: 'safety pin',
775
- 773: 'saltshaker, salt shaker',
776
- 774: 'sandal',
777
- 775: 'sarong',
778
- 776: 'sax, saxophone',
779
- 777: 'scabbard',
780
- 778: 'scale, weighing machine',
781
- 779: 'school bus',
782
- 780: 'schooner',
783
- 781: 'scoreboard',
784
- 782: 'screen, CRT screen',
785
- 783: 'screw',
786
- 784: 'screwdriver',
787
- 785: 'seat belt, seatbelt',
788
- 786: 'sewing machine',
789
- 787: 'shield, buckler',
790
- 788: 'shoe shop, shoe-shop, shoe store',
791
- 789: 'shoji',
792
- 790: 'shopping basket',
793
- 791: 'shopping cart',
794
- 792: 'shovel',
795
- 793: 'shower cap',
796
- 794: 'shower curtain',
797
- 795: 'ski',
798
- 796: 'ski mask',
799
- 797: 'sleeping bag',
800
- 798: 'slide rule, slipstick',
801
- 799: 'sliding door',
802
- 800: 'slot, one-armed bandit',
803
- 801: 'snorkel',
804
- 802: 'snowmobile',
805
- 803: 'snowplow, snowplough',
806
- 804: 'soap dispenser',
807
- 805: 'soccer ball',
808
- 806: 'sock',
809
- 807: 'solar dish, solar collector, solar furnace',
810
- 808: 'sombrero',
811
- 809: 'soup bowl',
812
- 810: 'space bar',
813
- 811: 'space heater',
814
- 812: 'space shuttle',
815
- 813: 'spatula',
816
- 814: 'speedboat',
817
- 815: "spider web, spider's web",
818
- 816: 'spindle',
819
- 817: 'sports car, sport car',
820
- 818: 'spotlight, spot',
821
- 819: 'stage',
822
- 820: 'steam locomotive',
823
- 821: 'steel arch bridge',
824
- 822: 'steel drum',
825
- 823: 'stethoscope',
826
- 824: 'stole',
827
- 825: 'stone wall',
828
- 826: 'stopwatch, stop watch',
829
- 827: 'stove',
830
- 828: 'strainer',
831
- 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
- 830: 'stretcher',
833
- 831: 'studio couch, day bed',
834
- 832: 'stupa, tope',
835
- 833: 'submarine, pigboat, sub, U-boat',
836
- 834: 'suit, suit of clothes',
837
- 835: 'sundial',
838
- 836: 'sunglass',
839
- 837: 'sunglasses, dark glasses, shades',
840
- 838: 'sunscreen, sunblock, sun blocker',
841
- 839: 'suspension bridge',
842
- 840: 'swab, swob, mop',
843
- 841: 'sweatshirt',
844
- 842: 'swimming trunks, bathing trunks',
845
- 843: 'swing',
846
- 844: 'switch, electric switch, electrical switch',
847
- 845: 'syringe',
848
- 846: 'table lamp',
849
- 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
- 848: 'tape player',
851
- 849: 'teapot',
852
- 850: 'teddy, teddy bear',
853
- 851: 'television, television system',
854
- 852: 'tennis ball',
855
- 853: 'thatch, thatched roof',
856
- 854: 'theater curtain, theatre curtain',
857
- 855: 'thimble',
858
- 856: 'thresher, thrasher, threshing machine',
859
- 857: 'throne',
860
- 858: 'tile roof',
861
- 859: 'toaster',
862
- 860: 'tobacco shop, tobacconist shop, tobacconist',
863
- 861: 'toilet seat',
864
- 862: 'torch',
865
- 863: 'totem pole',
866
- 864: 'tow truck, tow car, wrecker',
867
- 865: 'toyshop',
868
- 866: 'tractor',
869
- 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
- 868: 'tray',
871
- 869: 'trench coat',
872
- 870: 'tricycle, trike, velocipede',
873
- 871: 'trimaran',
874
- 872: 'tripod',
875
- 873: 'triumphal arch',
876
- 874: 'trolleybus, trolley coach, trackless trolley',
877
- 875: 'trombone',
878
- 876: 'tub, vat',
879
- 877: 'turnstile',
880
- 878: 'typewriter keyboard',
881
- 879: 'umbrella',
882
- 880: 'unicycle, monocycle',
883
- 881: 'upright, upright piano',
884
- 882: 'vacuum, vacuum cleaner',
885
- 883: 'vase',
886
- 884: 'vault',
887
- 885: 'velvet',
888
- 886: 'vending machine',
889
- 887: 'vestment',
890
- 888: 'viaduct',
891
- 889: 'violin, fiddle',
892
- 890: 'volleyball',
893
- 891: 'waffle iron',
894
- 892: 'wall clock',
895
- 893: 'wallet, billfold, notecase, pocketbook',
896
- 894: 'wardrobe, closet, press',
897
- 895: 'warplane, military plane',
898
- 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
- 897: 'washer, automatic washer, washing machine',
900
- 898: 'water bottle',
901
- 899: 'water jug',
902
- 900: 'water tower',
903
- 901: 'whiskey jug',
904
- 902: 'whistle',
905
- 903: 'wig',
906
- 904: 'window screen',
907
- 905: 'window shade',
908
- 906: 'Windsor tie',
909
- 907: 'wine bottle',
910
- 908: 'wing',
911
- 909: 'wok',
912
- 910: 'wooden spoon',
913
- 911: 'wool, woolen, woollen',
914
- 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
- 913: 'wreck',
916
- 914: 'yawl',
917
- 915: 'yurt',
918
- 916: 'web site, website, internet site, site',
919
- 917: 'comic book',
920
- 918: 'crossword puzzle, crossword',
921
- 919: 'street sign',
922
- 920: 'traffic light, traffic signal, stoplight',
923
- 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
- 922: 'menu',
925
- 923: 'plate',
926
- 924: 'guacamole',
927
- 925: 'consomme',
928
- 926: 'hot pot, hotpot',
929
- 927: 'trifle',
930
- 928: 'ice cream, icecream',
931
- 929: 'ice lolly, lolly, lollipop, popsicle',
932
- 930: 'French loaf',
933
- 931: 'bagel, beigel',
934
- 932: 'pretzel',
935
- 933: 'cheeseburger',
936
- 934: 'hotdog, hot dog, red hot',
937
- 935: 'mashed potato',
938
- 936: 'head cabbage',
939
- 937: 'broccoli',
940
- 938: 'cauliflower',
941
- 939: 'zucchini, courgette',
942
- 940: 'spaghetti squash',
943
- 941: 'acorn squash',
944
- 942: 'butternut squash',
945
- 943: 'cucumber, cuke',
946
- 944: 'artichoke, globe artichoke',
947
- 945: 'bell pepper',
948
- 946: 'cardoon',
949
- 947: 'mushroom',
950
- 948: 'Granny Smith',
951
- 949: 'strawberry',
952
- 950: 'orange',
953
- 951: 'lemon',
954
- 952: 'fig',
955
- 953: 'pineapple, ananas',
956
- 954: 'banana',
957
- 955: 'jackfruit, jak, jack',
958
- 956: 'custard apple',
959
- 957: 'pomegranate',
960
- 958: 'hay',
961
- 959: 'carbonara',
962
- 960: 'chocolate sauce, chocolate syrup',
963
- 961: 'dough',
964
- 962: 'meat loaf, meatloaf',
965
- 963: 'pizza, pizza pie',
966
- 964: 'potpie',
967
- 965: 'burrito',
968
- 966: 'red wine',
969
- 967: 'espresso',
970
- 968: 'cup',
971
- 969: 'eggnog',
972
- 970: 'alp',
973
- 971: 'bubble',
974
- 972: 'cliff, drop, drop-off',
975
- 973: 'coral reef',
976
- 974: 'geyser',
977
- 975: 'lakeside, lakeshore',
978
- 976: 'promontory, headland, head, foreland',
979
- 977: 'sandbar, sand bar',
980
- 978: 'seashore, coast, seacoast, sea-coast',
981
- 979: 'valley, vale',
982
- 980: 'volcano',
983
- 981: 'ballplayer, baseball player',
984
- 982: 'groom, bridegroom',
985
- 983: 'scuba diver',
986
- 984: 'rapeseed',
987
- 985: 'daisy',
988
- 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
- 987: 'corn',
990
- 988: 'acorn',
991
- 989: 'hip, rose hip, rosehip',
992
- 990: 'buckeye, horse chestnut, conker',
993
- 991: 'coral fungus',
994
- 992: 'agaric',
995
- 993: 'gyromitra',
996
- 994: 'stinkhorn, carrion fungus',
997
- 995: 'earthstar',
998
- 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
- 997: 'bolete',
1000
- 998: 'ear, spike, capitulum',
1001
- 999: 'toilet tissue, toilet paper, bathroom tissue'
1002
- }
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/transforms.py DELETED
@@ -1,442 +0,0 @@
1
- from __future__ import division
2
- import sys
3
- import random
4
- from PIL import Image
5
-
6
- try:
7
- import accimage
8
- except ImportError:
9
- accimage = None
10
- import numbers
11
- import collections
12
-
13
- from torchvision.transforms import functional as F
14
-
15
- if sys.version_info < (3, 3):
16
- Sequence = collections.Sequence
17
- Iterable = collections.Iterable
18
- else:
19
- Sequence = collections.abc.Sequence
20
- Iterable = collections.abc.Iterable
21
-
22
- _pil_interpolation_to_str = {
23
- Image.NEAREST: 'PIL.Image.NEAREST',
24
- Image.BILINEAR: 'PIL.Image.BILINEAR',
25
- Image.BICUBIC: 'PIL.Image.BICUBIC',
26
- Image.LANCZOS: 'PIL.Image.LANCZOS',
27
- Image.HAMMING: 'PIL.Image.HAMMING',
28
- Image.BOX: 'PIL.Image.BOX',
29
- }
30
-
31
-
32
- class Compose(object):
33
- """Composes several transforms together.
34
-
35
- Args:
36
- transforms (list of ``Transform`` objects): list of transforms to compose.
37
-
38
- Example:
39
- >>> transforms.Compose([
40
- >>> transforms.CenterCrop(10),
41
- >>> transforms.ToTensor(),
42
- >>> ])
43
- """
44
-
45
- def __init__(self, transforms):
46
- self.transforms = transforms
47
-
48
- def __call__(self, img, tgt):
49
- for t in self.transforms:
50
- img, tgt = t(img, tgt)
51
- return img, tgt
52
-
53
- def __repr__(self):
54
- format_string = self.__class__.__name__ + '('
55
- for t in self.transforms:
56
- format_string += '\n'
57
- format_string += ' {0}'.format(t)
58
- format_string += '\n)'
59
- return format_string
60
-
61
-
62
- class Resize(object):
63
- """Resize the input PIL Image to the given size.
64
-
65
- Args:
66
- size (sequence or int): Desired output size. If size is a sequence like
67
- (h, w), output size will be matched to this. If size is an int,
68
- smaller edge of the image will be matched to this number.
69
- i.e, if height > width, then image will be rescaled to
70
- (size * height / width, size)
71
- interpolation (int, optional): Desired interpolation. Default is
72
- ``PIL.Image.BILINEAR``
73
- """
74
-
75
- def __init__(self, size, interpolation=Image.BILINEAR):
76
- assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
77
- self.size = size
78
- self.interpolation = interpolation
79
-
80
- def __call__(self, img, tgt):
81
- """
82
- Args:
83
- img (PIL Image): Image to be scaled.
84
-
85
- Returns:
86
- PIL Image: Rescaled image.
87
- """
88
- return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)
89
-
90
- def __repr__(self):
91
- interpolate_str = _pil_interpolation_to_str[self.interpolation]
92
- return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
93
-
94
-
95
- class CenterCrop(object):
96
- """Crops the given PIL Image at the center.
97
-
98
- Args:
99
- size (sequence or int): Desired output size of the crop. If size is an
100
- int instead of sequence like (h, w), a square crop (size, size) is
101
- made.
102
- """
103
-
104
- def __init__(self, size):
105
- if isinstance(size, numbers.Number):
106
- self.size = (int(size), int(size))
107
- else:
108
- self.size = size
109
-
110
- def __call__(self, img, tgt):
111
- """
112
- Args:
113
- img (PIL Image): Image to be cropped.
114
-
115
- Returns:
116
- PIL Image: Cropped image.
117
- """
118
- return F.center_crop(img, self.size), F.center_crop(tgt, self.size)
119
-
120
- def __repr__(self):
121
- return self.__class__.__name__ + '(size={0})'.format(self.size)
122
-
123
-
124
- class RandomCrop(object):
125
- """Crop the given PIL Image at a random location.
126
-
127
- Args:
128
- size (sequence or int): Desired output size of the crop. If size is an
129
- int instead of sequence like (h, w), a square crop (size, size) is
130
- made.
131
- padding (int or sequence, optional): Optional padding on each border
132
- of the image. Default is None, i.e no padding. If a sequence of length
133
- 4 is provided, it is used to pad left, top, right, bottom borders
134
- respectively. If a sequence of length 2 is provided, it is used to
135
- pad left/right, top/bottom borders, respectively.
136
- pad_if_needed (boolean): It will pad the image if smaller than the
137
- desired size to avoid raising an exception.
138
- fill: Pixel fill value for constant fill. Default is 0. If a tuple of
139
- length 3, it is used to fill R, G, B channels respectively.
140
- This value is only used when the padding_mode is constant
141
- padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
142
-
143
- - constant: pads with a constant value, this value is specified with fill
144
-
145
- - edge: pads with the last value on the edge of the image
146
-
147
- - reflect: pads with reflection of image (without repeating the last value on the edge)
148
-
149
- padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
150
- will result in [3, 2, 1, 2, 3, 4, 3, 2]
151
-
152
- - symmetric: pads with reflection of image (repeating the last value on the edge)
153
-
154
- padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
155
- will result in [2, 1, 1, 2, 3, 4, 4, 3]
156
-
157
- """
158
-
159
- def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
160
- if isinstance(size, numbers.Number):
161
- self.size = (int(size), int(size))
162
- else:
163
- self.size = size
164
- self.padding = padding
165
- self.pad_if_needed = pad_if_needed
166
- self.fill = fill
167
- self.padding_mode = padding_mode
168
-
169
- @staticmethod
170
- def get_params(img, output_size):
171
- """Get parameters for ``crop`` for a random crop.
172
-
173
- Args:
174
- img (PIL Image): Image to be cropped.
175
- output_size (tuple): Expected output size of the crop.
176
-
177
- Returns:
178
- tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
179
- """
180
- w, h = img.size
181
- th, tw = output_size
182
- if w == tw and h == th:
183
- return 0, 0, h, w
184
-
185
- i = random.randint(0, h - th)
186
- j = random.randint(0, w - tw)
187
- return i, j, th, tw
188
-
189
- def __call__(self, img, tgt):
190
- """
191
- Args:
192
- img (PIL Image): Image to be cropped.
193
-
194
- Returns:
195
- PIL Image: Cropped image.
196
- """
197
- if self.padding is not None:
198
- img = F.pad(img, self.padding, self.fill, self.padding_mode)
199
- tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)
200
-
201
- # pad the width if needed
202
- if self.pad_if_needed and img.size[0] < self.size[1]:
203
- img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
204
- tgt = F.pad(tgt, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
205
- # pad the height if needed
206
- if self.pad_if_needed and img.size[1] < self.size[0]:
207
- img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
208
- tgt = F.pad(tgt, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
209
-
210
- i, j, h, w = self.get_params(img, self.size)
211
-
212
- return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)
213
-
214
- def __repr__(self):
215
- return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
216
-
217
-
218
- class RandomHorizontalFlip(object):
219
- """Horizontally flip the given PIL Image randomly with a given probability.
220
-
221
- Args:
222
- p (float): probability of the image being flipped. Default value is 0.5
223
- """
224
-
225
- def __init__(self, p=0.5):
226
- self.p = p
227
-
228
- def __call__(self, img, tgt):
229
- """
230
- Args:
231
- img (PIL Image): Image to be flipped.
232
-
233
- Returns:
234
- PIL Image: Randomly flipped image.
235
- """
236
- if random.random() < self.p:
237
- return F.hflip(img), F.hflip(tgt)
238
-
239
- return img, tgt
240
-
241
- def __repr__(self):
242
- return self.__class__.__name__ + '(p={})'.format(self.p)
243
-
244
-
245
- class RandomVerticalFlip(object):
246
- """Vertically flip the given PIL Image randomly with a given probability.
247
-
248
- Args:
249
- p (float): probability of the image being flipped. Default value is 0.5
250
- """
251
-
252
- def __init__(self, p=0.5):
253
- self.p = p
254
-
255
- def __call__(self, img, tgt):
256
- """
257
- Args:
258
- img (PIL Image): Image to be flipped.
259
-
260
- Returns:
261
- PIL Image: Randomly flipped image.
262
- """
263
- if random.random() < self.p:
264
- return F.vflip(img), F.vflip(tgt)
265
- return img, tgt
266
-
267
- def __repr__(self):
268
- return self.__class__.__name__ + '(p={})'.format(self.p)
269
-
270
-
271
- class Lambda(object):
272
- """Apply a user-defined lambda as a transform.
273
-
274
- Args:
275
- lambd (function): Lambda/function to be used for transform.
276
- """
277
-
278
- def __init__(self, lambd):
279
- assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
280
- self.lambd = lambd
281
-
282
- def __call__(self, img, tgt):
283
- return self.lambd(img, tgt)
284
-
285
- def __repr__(self):
286
- return self.__class__.__name__ + '()'
287
-
288
-
289
- class ColorJitter(object):
290
- """Randomly change the brightness, contrast and saturation of an image.
291
-
292
- Args:
293
- brightness (float or tuple of float (min, max)): How much to jitter brightness.
294
- brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
295
- or the given [min, max]. Should be non negative numbers.
296
- contrast (float or tuple of float (min, max)): How much to jitter contrast.
297
- contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
298
- or the given [min, max]. Should be non negative numbers.
299
- saturation (float or tuple of float (min, max)): How much to jitter saturation.
300
- saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
301
- or the given [min, max]. Should be non negative numbers.
302
- hue (float or tuple of float (min, max)): How much to jitter hue.
303
- hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
304
- Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
305
- """
306
- def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
307
- self.brightness = self._check_input(brightness, 'brightness')
308
- self.contrast = self._check_input(contrast, 'contrast')
309
- self.saturation = self._check_input(saturation, 'saturation')
310
- self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
311
- clip_first_on_zero=False)
312
-
313
- def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
314
- if isinstance(value, numbers.Number):
315
- if value < 0:
316
- raise ValueError("If {} is a single number, it must be non negative.".format(name))
317
- value = [center - value, center + value]
318
- if clip_first_on_zero:
319
- value[0] = max(value[0], 0)
320
- elif isinstance(value, (tuple, list)) and len(value) == 2:
321
- if not bound[0] <= value[0] <= value[1] <= bound[1]:
322
- raise ValueError("{} values should be between {}".format(name, bound))
323
- else:
324
- raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
325
-
326
- # if value is 0 or (1., 1.) for brightness/contrast/saturation
327
- # or (0., 0.) for hue, do nothing
328
- if value[0] == value[1] == center:
329
- value = None
330
- return value
331
-
332
- @staticmethod
333
- def get_params(brightness, contrast, saturation, hue):
334
- """Get a randomized transform to be applied on image.
335
-
336
- Arguments are same as that of __init__.
337
-
338
- Returns:
339
- Transform which randomly adjusts brightness, contrast and
340
- saturation in a random order.
341
- """
342
- transforms = []
343
-
344
- if brightness is not None:
345
- brightness_factor = random.uniform(brightness[0], brightness[1])
346
- transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))
347
-
348
- if contrast is not None:
349
- contrast_factor = random.uniform(contrast[0], contrast[1])
350
- transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))
351
-
352
- if saturation is not None:
353
- saturation_factor = random.uniform(saturation[0], saturation[1])
354
- transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))
355
-
356
- if hue is not None:
357
- hue_factor = random.uniform(hue[0], hue[1])
358
- transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))
359
-
360
- random.shuffle(transforms)
361
- transform = Compose(transforms)
362
-
363
- return transform
364
-
365
- def __call__(self, img, tgt):
366
- """
367
- Args:
368
- img (PIL Image): Input image.
369
-
370
- Returns:
371
- PIL Image: Color jittered image.
372
- """
373
- transform = self.get_params(self.brightness, self.contrast,
374
- self.saturation, self.hue)
375
- return transform(img, tgt)
376
-
377
- def __repr__(self):
378
- format_string = self.__class__.__name__ + '('
379
- format_string += 'brightness={0}'.format(self.brightness)
380
- format_string += ', contrast={0}'.format(self.contrast)
381
- format_string += ', saturation={0}'.format(self.saturation)
382
- format_string += ', hue={0})'.format(self.hue)
383
- return format_string
384
-
385
-
386
- class Normalize(object):
387
- """Normalize a tensor image with mean and standard deviation.
388
- Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
389
- will normalize each channel of the input ``torch.*Tensor`` i.e.
390
- ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
391
-
392
- .. note::
393
- This transform acts out of place, i.e., it does not mutates the input tensor.
394
-
395
- Args:
396
- mean (sequence): Sequence of means for each channel.
397
- std (sequence): Sequence of standard deviations for each channel.
398
- """
399
-
400
- def __init__(self, mean, std, inplace=False):
401
- self.mean = mean
402
- self.std = std
403
- self.inplace = inplace
404
-
405
- def __call__(self, img, tgt):
406
- """
407
- Args:
408
- tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
409
-
410
- Returns:
411
- Tensor: Normalized Tensor image.
412
- """
413
- # return F.normalize(img, self.mean, self.std, self.inplace), tgt
414
- return F.normalize(img, self.mean, self.std), tgt
415
-
416
- def __repr__(self):
417
- return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
418
-
419
-
420
- class ToTensor(object):
421
- """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
422
-
423
- Converts a PIL Image or numpy.ndarray (H x W x C) in the range
424
- [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
425
- if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
426
- or if the numpy.ndarray has dtype = np.uint8
427
-
428
- In the other cases, tensors are returned without scaling.
429
- """
430
-
431
- def __call__(self, img, tgt):
432
- """
433
- Args:
434
- pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
435
-
436
- Returns:
437
- Tensor: Converted image.
438
- """
439
- return F.to_tensor(img), tgt
440
-
441
- def __repr__(self):
442
- return self.__class__.__name__ + '()'
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/generate_visualizations.py DELETED
@@ -1,208 +0,0 @@
1
- import os
2
- from tqdm import tqdm
3
- import h5py
4
-
5
- import argparse
6
-
7
- # Import saliency methods and models
8
- from misc_functions import *
9
-
10
- from ViT_explanation_generator import Baselines, LRP
11
- from ViT_new import vit_base_patch16_224
12
- from ViT_LRP import vit_base_patch16_224 as vit_LRP
13
- from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
14
-
15
- from torchvision.datasets import ImageNet
16
-
17
-
18
- def normalize(tensor,
19
- mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
20
- dtype = tensor.dtype
21
- mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
22
- std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
23
- tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
24
- return tensor
25
-
26
-
27
- def compute_saliency_and_save(args):
28
- first = True
29
- with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
30
- data_cam = f.create_dataset('vis',
31
- (1, 1, 224, 224),
32
- maxshape=(None, 1, 224, 224),
33
- dtype=np.float32,
34
- compression="gzip")
35
- data_image = f.create_dataset('image',
36
- (1, 3, 224, 224),
37
- maxshape=(None, 3, 224, 224),
38
- dtype=np.float32,
39
- compression="gzip")
40
- data_target = f.create_dataset('target',
41
- (1,),
42
- maxshape=(None,),
43
- dtype=np.int32,
44
- compression="gzip")
45
- for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
46
- if first:
47
- first = False
48
- data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0)
49
- data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0)
50
- data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0)
51
- else:
52
- data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0)
53
- data_image.resize(data_image.shape[0] + data.shape[0], axis=0)
54
- data_target.resize(data_target.shape[0] + data.shape[0], axis=0)
55
-
56
- # Add data
57
- data_image[-data.shape[0]:] = data.data.cpu().numpy()
58
- data_target[-data.shape[0]:] = target.data.cpu().numpy()
59
-
60
- target = target.to(device)
61
-
62
- data = normalize(data)
63
- data = data.to(device)
64
- data.requires_grad_()
65
-
66
- index = None
67
- if args.vis_class == 'target':
68
- index = target
69
-
70
- if args.method == 'rollout':
71
- Res = baselines.generate_rollout(data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
72
- # Res = Res - Res.mean()
73
-
74
- elif args.method == 'lrp':
75
- Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
76
- # Res = Res - Res.mean()
77
-
78
- elif args.method == 'transformer_attribution':
79
- Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14)
80
- # Res = Res - Res.mean()
81
-
82
- elif args.method == 'full_lrp':
83
- Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
84
- # Res = Res - Res.mean()
85
-
86
- elif args.method == 'lrp_last_layer':
87
- Res = orig_lrp.generate_LRP(data, method="last_layer", is_ablation=args.is_ablation, index=index) \
88
- .reshape(data.shape[0], 1, 14, 14)
89
- # Res = Res - Res.mean()
90
-
91
- elif args.method == 'attn_last_layer':
92
- Res = lrp.generate_LRP(data, method="last_layer_attn", is_ablation=args.is_ablation) \
93
- .reshape(data.shape[0], 1, 14, 14)
94
-
95
- elif args.method == 'attn_gradcam':
96
- Res = baselines.generate_cam_attn(data, index=index).reshape(data.shape[0], 1, 14, 14)
97
-
98
- if args.method != 'full_lrp' and args.method != 'input_grads':
99
- Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
100
- Res = (Res - Res.min()) / (Res.max() - Res.min())
101
-
102
- data_cam[-data.shape[0]:] = Res.data.cpu().numpy()
103
-
104
-
105
- if __name__ == "__main__":
106
- parser = argparse.ArgumentParser(description='Train a segmentation')
107
- parser.add_argument('--batch-size', type=int,
108
- default=1,
109
- help='')
110
- parser.add_argument('--method', type=str,
111
- default='grad_rollout',
112
- choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
113
- 'attn_last_layer', 'attn_gradcam'],
114
- help='')
115
- parser.add_argument('--lmd', type=float,
116
- default=10,
117
- help='')
118
- parser.add_argument('--vis-class', type=str,
119
- default='top',
120
- choices=['top', 'target', 'index'],
121
- help='')
122
- parser.add_argument('--class-id', type=int,
123
- default=0,
124
- help='')
125
- parser.add_argument('--cls-agn', action='store_true',
126
- default=False,
127
- help='')
128
- parser.add_argument('--no-ia', action='store_true',
129
- default=False,
130
- help='')
131
- parser.add_argument('--no-fx', action='store_true',
132
- default=False,
133
- help='')
134
- parser.add_argument('--no-fgx', action='store_true',
135
- default=False,
136
- help='')
137
- parser.add_argument('--no-m', action='store_true',
138
- default=False,
139
- help='')
140
- parser.add_argument('--no-reg', action='store_true',
141
- default=False,
142
- help='')
143
- parser.add_argument('--is-ablation', type=bool,
144
- default=False,
145
- help='')
146
- parser.add_argument('--imagenet-validation-path', type=str,
147
- required=True,
148
- help='')
149
- args = parser.parse_args()
150
-
151
- # PATH variables
152
- PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
153
- os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True)
154
-
155
- try:
156
- os.remove(os.path.join(PATH, 'visualizations/{}/{}/results.hdf5'.format(args.method,
157
- args.vis_class)))
158
- except OSError:
159
- pass
160
-
161
-
162
- os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True)
163
- if args.vis_class == 'index':
164
- os.makedirs(os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
165
- args.vis_class,
166
- args.class_id)), exist_ok=True)
167
- args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
168
- args.vis_class,
169
- args.class_id))
170
- else:
171
- ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
172
- os.makedirs(os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
173
- args.vis_class, ablation_fold)), exist_ok=True)
174
- args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
175
- args.vis_class, ablation_fold))
176
-
177
- cuda = torch.cuda.is_available()
178
- device = torch.device("cuda" if cuda else "cpu")
179
-
180
- # Model
181
- model = vit_base_patch16_224(pretrained=True).cuda()
182
- baselines = Baselines(model)
183
-
184
- # LRP
185
- model_LRP = vit_LRP(pretrained=True).cuda()
186
- model_LRP.eval()
187
- lrp = LRP(model_LRP)
188
-
189
- # orig LRP
190
- model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
191
- model_orig_LRP.eval()
192
- orig_lrp = LRP(model_orig_LRP)
193
-
194
- # Dataset loader for sample images
195
- transform = transforms.Compose([
196
- transforms.Resize((224, 224)),
197
- transforms.ToTensor(),
198
- ])
199
-
200
- imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform)
201
- sample_loader = torch.utils.data.DataLoader(
202
- imagenet_ds,
203
- batch_size=args.batch_size,
204
- shuffle=False,
205
- num_workers=4
206
- )
207
-
208
- compute_saliency_and_save(args)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/helpers.py DELETED
@@ -1,295 +0,0 @@
1
- """ Model creation / weight loading / state_dict helpers
2
-
3
- Hacked together by / Copyright 2020 Ross Wightman
4
- """
5
- import logging
6
- import os
7
- import math
8
- from collections import OrderedDict
9
- from copy import deepcopy
10
- from typing import Callable
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.utils.model_zoo as model_zoo
15
-
16
- _logger = logging.getLogger(__name__)
17
-
18
-
19
- def load_state_dict(checkpoint_path, use_ema=False):
20
- if checkpoint_path and os.path.isfile(checkpoint_path):
21
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
22
- state_dict_key = 'state_dict'
23
- if isinstance(checkpoint, dict):
24
- if use_ema and 'state_dict_ema' in checkpoint:
25
- state_dict_key = 'state_dict_ema'
26
- if state_dict_key and state_dict_key in checkpoint:
27
- new_state_dict = OrderedDict()
28
- for k, v in checkpoint[state_dict_key].items():
29
- # strip `module.` prefix
30
- name = k[7:] if k.startswith('module') else k
31
- new_state_dict[name] = v
32
- state_dict = new_state_dict
33
- else:
34
- state_dict = checkpoint
35
- _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
36
- return state_dict
37
- else:
38
- _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
39
- raise FileNotFoundError()
40
-
41
-
42
- def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
43
- state_dict = load_state_dict(checkpoint_path, use_ema)
44
- model.load_state_dict(state_dict, strict=strict)
45
-
46
-
47
- def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
48
- resume_epoch = None
49
- if os.path.isfile(checkpoint_path):
50
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
51
- if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52
- if log_info:
53
- _logger.info('Restoring model state from checkpoint...')
54
- new_state_dict = OrderedDict()
55
- for k, v in checkpoint['state_dict'].items():
56
- name = k[7:] if k.startswith('module') else k
57
- new_state_dict[name] = v
58
- model.load_state_dict(new_state_dict)
59
-
60
- if optimizer is not None and 'optimizer' in checkpoint:
61
- if log_info:
62
- _logger.info('Restoring optimizer state from checkpoint...')
63
- optimizer.load_state_dict(checkpoint['optimizer'])
64
-
65
- if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
66
- if log_info:
67
- _logger.info('Restoring AMP loss scaler state from checkpoint...')
68
- loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
69
-
70
- if 'epoch' in checkpoint:
71
- resume_epoch = checkpoint['epoch']
72
- if 'version' in checkpoint and checkpoint['version'] > 1:
73
- resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
74
-
75
- if log_info:
76
- _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
77
- else:
78
- model.load_state_dict(checkpoint)
79
- if log_info:
80
- _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
81
- return resume_epoch
82
- else:
83
- _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
84
- raise FileNotFoundError()
85
-
86
-
87
- def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
88
- if cfg is None:
89
- cfg = getattr(model, 'default_cfg')
90
- if cfg is None or 'url' not in cfg or not cfg['url']:
91
- _logger.warning("Pretrained model URL is invalid, using random initialization.")
92
- return
93
-
94
- state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
95
-
96
- if filter_fn is not None:
97
- state_dict = filter_fn(state_dict)
98
-
99
- if in_chans == 1:
100
- conv1_name = cfg['first_conv']
101
- _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
102
- conv1_weight = state_dict[conv1_name + '.weight']
103
- # Some weights are in torch.half, ensure it's float for sum on CPU
104
- conv1_type = conv1_weight.dtype
105
- conv1_weight = conv1_weight.float()
106
- O, I, J, K = conv1_weight.shape
107
- if I > 3:
108
- assert conv1_weight.shape[1] % 3 == 0
109
- # For models with space2depth stems
110
- conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
111
- conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
112
- else:
113
- conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
114
- conv1_weight = conv1_weight.to(conv1_type)
115
- state_dict[conv1_name + '.weight'] = conv1_weight
116
- elif in_chans != 3:
117
- conv1_name = cfg['first_conv']
118
- conv1_weight = state_dict[conv1_name + '.weight']
119
- conv1_type = conv1_weight.dtype
120
- conv1_weight = conv1_weight.float()
121
- O, I, J, K = conv1_weight.shape
122
- if I != 3:
123
- _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
124
- del state_dict[conv1_name + '.weight']
125
- strict = False
126
- else:
127
- # NOTE this strategy should be better than random init, but there could be other combinations of
128
- # the original RGB input layer weights that'd work better for specific cases.
129
- _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
130
- repeat = int(math.ceil(in_chans / 3))
131
- conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
132
- conv1_weight *= (3 / float(in_chans))
133
- conv1_weight = conv1_weight.to(conv1_type)
134
- state_dict[conv1_name + '.weight'] = conv1_weight
135
-
136
- classifier_name = cfg['classifier']
137
- if num_classes == 1000 and cfg['num_classes'] == 1001:
138
- # special case for imagenet trained models with extra background class in pretrained weights
139
- classifier_weight = state_dict[classifier_name + '.weight']
140
- state_dict[classifier_name + '.weight'] = classifier_weight[1:]
141
- classifier_bias = state_dict[classifier_name + '.bias']
142
- state_dict[classifier_name + '.bias'] = classifier_bias[1:]
143
- elif num_classes != cfg['num_classes']:
144
- # completely discard fully connected for all other differences between pretrained and created model
145
- del state_dict[classifier_name + '.weight']
146
- del state_dict[classifier_name + '.bias']
147
- strict = False
148
-
149
- model.load_state_dict(state_dict, strict=strict)
150
-
151
-
152
- def extract_layer(model, layer):
153
- layer = layer.split('.')
154
- module = model
155
- if hasattr(model, 'module') and layer[0] != 'module':
156
- module = model.module
157
- if not hasattr(model, 'module') and layer[0] == 'module':
158
- layer = layer[1:]
159
- for l in layer:
160
- if hasattr(module, l):
161
- if not l.isdigit():
162
- module = getattr(module, l)
163
- else:
164
- module = module[int(l)]
165
- else:
166
- return module
167
- return module
168
-
169
-
170
- def set_layer(model, layer, val):
171
- layer = layer.split('.')
172
- module = model
173
- if hasattr(model, 'module') and layer[0] != 'module':
174
- module = model.module
175
- lst_index = 0
176
- module2 = module
177
- for l in layer:
178
- if hasattr(module2, l):
179
- if not l.isdigit():
180
- module2 = getattr(module2, l)
181
- else:
182
- module2 = module2[int(l)]
183
- lst_index += 1
184
- lst_index -= 1
185
- for l in layer[:lst_index]:
186
- if not l.isdigit():
187
- module = getattr(module, l)
188
- else:
189
- module = module[int(l)]
190
- l = layer[lst_index]
191
- setattr(module, l, val)
192
-
193
-
194
- def adapt_model_from_string(parent_module, model_string):
195
- separator = '***'
196
- state_dict = {}
197
- lst_shape = model_string.split(separator)
198
- for k in lst_shape:
199
- k = k.split(':')
200
- key = k[0]
201
- shape = k[1][1:-1].split(',')
202
- if shape[0] != '':
203
- state_dict[key] = [int(i) for i in shape]
204
-
205
- new_module = deepcopy(parent_module)
206
- for n, m in parent_module.named_modules():
207
- old_module = extract_layer(parent_module, n)
208
- if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
209
- if isinstance(old_module, Conv2dSame):
210
- conv = Conv2dSame
211
- else:
212
- conv = nn.Conv2d
213
- s = state_dict[n + '.weight']
214
- in_channels = s[1]
215
- out_channels = s[0]
216
- g = 1
217
- if old_module.groups > 1:
218
- in_channels = out_channels
219
- g = in_channels
220
- new_conv = conv(
221
- in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
222
- bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
223
- groups=g, stride=old_module.stride)
224
- set_layer(new_module, n, new_conv)
225
- if isinstance(old_module, nn.BatchNorm2d):
226
- new_bn = nn.BatchNorm2d(
227
- num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
228
- affine=old_module.affine, track_running_stats=True)
229
- set_layer(new_module, n, new_bn)
230
- if isinstance(old_module, nn.Linear):
231
- # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
232
- num_features = state_dict[n + '.weight'][1]
233
- new_fc = nn.Linear(
234
- in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
235
- set_layer(new_module, n, new_fc)
236
- if hasattr(new_module, 'num_features'):
237
- new_module.num_features = num_features
238
- new_module.eval()
239
- parent_module.eval()
240
-
241
- return new_module
242
-
243
-
244
- def adapt_model_from_file(parent_module, model_variant):
245
- adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
246
- with open(adapt_file, 'r') as f:
247
- return adapt_model_from_string(parent_module, f.read().strip())
248
-
249
-
250
- def build_model_with_cfg(
251
- model_cls: Callable,
252
- variant: str,
253
- pretrained: bool,
254
- default_cfg: dict,
255
- model_cfg: dict = None,
256
- feature_cfg: dict = None,
257
- pretrained_strict: bool = True,
258
- pretrained_filter_fn: Callable = None,
259
- **kwargs):
260
- pruned = kwargs.pop('pruned', False)
261
- features = False
262
- feature_cfg = feature_cfg or {}
263
-
264
- if kwargs.pop('features_only', False):
265
- features = True
266
- feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
267
- if 'out_indices' in kwargs:
268
- feature_cfg['out_indices'] = kwargs.pop('out_indices')
269
-
270
- model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
271
- model.default_cfg = deepcopy(default_cfg)
272
-
273
- if pruned:
274
- model = adapt_model_from_file(model, variant)
275
-
276
- if pretrained:
277
- load_pretrained(
278
- model,
279
- num_classes=kwargs.get('num_classes', 0),
280
- in_chans=kwargs.get('in_chans', 3),
281
- filter_fn=pretrained_filter_fn, strict=pretrained_strict)
282
-
283
- if features:
284
- feature_cls = FeatureListNet
285
- if 'feature_cls' in feature_cfg:
286
- feature_cls = feature_cfg.pop('feature_cls')
287
- if isinstance(feature_cls, str):
288
- feature_cls = feature_cls.lower()
289
- if 'hook' in feature_cls:
290
- feature_cls = FeatureHookNet
291
- else:
292
- assert False, f'Unknown feature class {feature_cls}'
293
- model = feature_cls(model, **feature_cfg)
294
-
295
- return model
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/layer_helpers.py DELETED
@@ -1,21 +0,0 @@
- """ Layer/Module Helpers
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from itertools import repeat
- import collections.abc
-
-
- # From PyTorch internals
- def _ntuple(n):
-     def parse(x):
-         if isinstance(x, collections.abc.Iterable):
-             return x
-         return tuple(repeat(x, n))
-     return parse
-
-
- to_1tuple = _ntuple(1)
- to_2tuple = _ntuple(2)
- to_3tuple = _ntuple(3)
- to_4tuple = _ntuple(4)
- to_ntuple = _ntuple
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/misc_functions.py DELETED
@@ -1,68 +0,0 @@
- #
- # Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/
- # Written by Suraj Srinivas <[email protected]>
- #
-
- """ Misc helper functions """
-
- import cv2
- import numpy as np
- import subprocess
-
- import torch
- import torchvision.transforms as transforms
-
-
- class NormalizeInverse(transforms.Normalize):
-     # Undo normalization on images
-
-     def __init__(self, mean, std):
-         mean = torch.as_tensor(mean)
-         std = torch.as_tensor(std)
-         std_inv = 1 / (std + 1e-7)
-         mean_inv = -mean * std_inv
-         super(NormalizeInverse, self).__init__(mean=mean_inv, std=std_inv)
-
-     def __call__(self, tensor):
-         return super(NormalizeInverse, self).__call__(tensor.clone())
-
-
- def create_folder(folder_name):
-     try:
-         subprocess.call(['mkdir', '-p', folder_name])
-     except OSError:
-         None
-
-
- def save_saliency_map(image, saliency_map, filename):
-     """
-     Save saliency map on image.
-
-     Args:
-         image: Tensor of size (3,H,W)
-         saliency_map: Tensor of size (1,H,W)
-         filename: string with complete path and file extension
-
-     """
-
-     image = image.data.cpu().numpy()
-     saliency_map = saliency_map.data.cpu().numpy()
-
-     saliency_map = saliency_map - saliency_map.min()
-     saliency_map = saliency_map / saliency_map.max()
-     saliency_map = saliency_map.clip(0, 1)
-
-     saliency_map = np.uint8(saliency_map * 255).transpose(1, 2, 0)
-     saliency_map = cv2.resize(saliency_map, (224, 224))
-
-     image = np.uint8(image * 255).transpose(1, 2, 0)
-     image = cv2.resize(image, (224, 224))
-
-     # Apply JET colormap
-     color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
-
-     # Combine image with heatmap
-     img_with_heatmap = np.float32(color_heatmap) + np.float32(image)
-     img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
-
-     cv2.imwrite(filename, np.uint8(255 * img_with_heatmap))
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__init__.py DELETED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (223 Bytes)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_lrp.cpython-310.pyc DELETED
Binary file (9.31 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_ours.cpython-310.pyc DELETED
Binary file (9.75 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_lrp.py DELETED
@@ -1,261 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
- 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
- 'LayerNorm', 'AddEye']
8
-
9
-
10
- def safe_divide(a, b):
11
- den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
- den = den + den.eq(0).type(den.type()) * 1e-9
13
- return a / den * b.ne(0).type(b.type())
14
-
15
-
16
- def forward_hook(self, input, output):
17
- if type(input[0]) in (list, tuple):
18
- self.X = []
19
- for i in input[0]:
20
- x = i.detach()
21
- x.requires_grad = True
22
- self.X.append(x)
23
- else:
24
- self.X = input[0].detach()
25
- self.X.requires_grad = True
26
-
27
- self.Y = output
28
-
29
-
30
- def backward_hook(self, grad_input, grad_output):
31
- self.grad_input = grad_input
32
- self.grad_output = grad_output
33
-
34
-
35
- class RelProp(nn.Module):
36
- def __init__(self):
37
-         super(RelProp, self).__init__()
-         # if not self.training:
-         self.register_forward_hook(forward_hook)
-
-     def gradprop(self, Z, X, S):
-         C = torch.autograd.grad(Z, X, S, retain_graph=True)
-         return C
-
-     def relprop(self, R, alpha):
-         return R
-
-
- class RelPropSimple(RelProp):
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         if torch.is_tensor(self.X) == False:
-             outputs = []
-             outputs.append(self.X[0] * C[0])
-             outputs.append(self.X[1] * C[1])
-         else:
-             outputs = self.X * (C[0])
-         return outputs
-
- class AddEye(RelPropSimple):
-     # input of shape B, C, seq_len, seq_len
-     def forward(self, input):
-         return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
-
- class ReLU(nn.ReLU, RelProp):
-     pass
-
- class GELU(nn.GELU, RelProp):
-     pass
-
- class Softmax(nn.Softmax, RelProp):
-     pass
-
- class LayerNorm(nn.LayerNorm, RelProp):
-     pass
-
- class Dropout(nn.Dropout, RelProp):
-     pass
-
-
- class MaxPool2d(nn.MaxPool2d, RelPropSimple):
-     pass
-
- class LayerNorm(nn.LayerNorm, RelProp):
-     pass
-
- class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
-     pass
-
-
- class AvgPool2d(nn.AvgPool2d, RelPropSimple):
-     pass
-
-
- class Add(RelPropSimple):
-     def forward(self, inputs):
-         return torch.add(*inputs)
-
- class einsum(RelPropSimple):
-     def __init__(self, equation):
-         super().__init__()
-         self.equation = equation
-     def forward(self, *operands):
-         return torch.einsum(self.equation, *operands)
-
- class IndexSelect(RelProp):
-     def forward(self, inputs, dim, indices):
-         self.__setattr__('dim', dim)
-         self.__setattr__('indices', indices)
-
-         return torch.index_select(inputs, dim, indices)
-
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X, self.dim, self.indices)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         if torch.is_tensor(self.X) == False:
-             outputs = []
-             outputs.append(self.X[0] * C[0])
-             outputs.append(self.X[1] * C[1])
-         else:
-             outputs = self.X * (C[0])
-         return outputs
-
-
-
- class Clone(RelProp):
-     def forward(self, input, num):
-         self.__setattr__('num', num)
-         outputs = []
-         for _ in range(num):
-             outputs.append(input)
-
-         return outputs
-
-     def relprop(self, R, alpha):
-         Z = []
-         for _ in range(self.num):
-             Z.append(self.X)
-         S = [safe_divide(r, z) for r, z in zip(R, Z)]
-         C = self.gradprop(Z, self.X, S)[0]
-
-         R = self.X * C
-
-         return R
-
- class Cat(RelProp):
-     def forward(self, inputs, dim):
-         self.__setattr__('dim', dim)
-         return torch.cat(inputs, dim)
-
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X, self.dim)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         outputs = []
-         for x, c in zip(self.X, C):
-             outputs.append(x * c)
-
-         return outputs
-
-
- class Sequential(nn.Sequential):
-     def relprop(self, R, alpha):
-         for m in reversed(self._modules.values()):
-             R = m.relprop(R, alpha)
-         return R
-
-
- class BatchNorm2d(nn.BatchNorm2d, RelProp):
-     def relprop(self, R, alpha):
-         X = self.X
-         beta = 1 - alpha
-         weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
-             (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
-         Z = X * weight + 1e-9
-         S = R / Z
-         Ca = S * weight
-         R = self.X * (Ca)
-         return R
-
-
- class Linear(nn.Linear, RelProp):
-     def relprop(self, R, alpha):
-         beta = alpha - 1
-         pw = torch.clamp(self.weight, min=0)
-         nw = torch.clamp(self.weight, max=0)
-         px = torch.clamp(self.X, min=0)
-         nx = torch.clamp(self.X, max=0)
-
-         def f(w1, w2, x1, x2):
-             Z1 = F.linear(x1, w1)
-             Z2 = F.linear(x2, w2)
-             S1 = safe_divide(R, Z1)
-             S2 = safe_divide(R, Z2)
-             C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
-             C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
-
-             return C1 + C2
-
-         activator_relevances = f(pw, nw, px, nx)
-         inhibitor_relevances = f(nw, pw, px, nx)
-
-         R = alpha * activator_relevances - beta * inhibitor_relevances
-
-         return R
-
-
- class Conv2d(nn.Conv2d, RelProp):
-     def gradprop2(self, DY, weight):
-         Z = self.forward(self.X)
-
-         output_padding = self.X.size()[2] - (
-             (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
-
-         return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
-
-     def relprop(self, R, alpha):
-         if self.X.shape[1] == 3:
-             pw = torch.clamp(self.weight, min=0)
-             nw = torch.clamp(self.weight, max=0)
-             X = self.X
-             L = self.X * 0 + \
-                 torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
-                           keepdim=True)[0]
-             H = self.X * 0 + \
-                 torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
-                           keepdim=True)[0]
-             Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
-                  torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
-                  torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
-
-             S = R / Za
-             C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
-             R = C
-         else:
-             beta = alpha - 1
-             pw = torch.clamp(self.weight, min=0)
-             nw = torch.clamp(self.weight, max=0)
-             px = torch.clamp(self.X, min=0)
-             nx = torch.clamp(self.X, max=0)
-
-             def f(w1, w2, x1, x2):
-                 Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
-                 Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
-                 S1 = safe_divide(R, Z1)
-                 S2 = safe_divide(R, Z2)
-                 C1 = x1 * self.gradprop(Z1, x1, S1)[0]
-                 C2 = x2 * self.gradprop(Z2, x2, S2)[0]
-                 return C1 + C2
-
-             activator_relevances = f(pw, nw, px, nx)
-             inhibitor_relevances = f(nw, pw, px, nx)
-
-             R = alpha * activator_relevances - beta * inhibitor_relevances
-         return R
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_ours.py DELETED
@@ -1,280 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
-            'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
-            'LayerNorm', 'AddEye']
-
-
- def safe_divide(a, b):
-     den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
-     den = den + den.eq(0).type(den.type()) * 1e-9
-     return a / den * b.ne(0).type(b.type())
-
-
- def forward_hook(self, input, output):
-     if type(input[0]) in (list, tuple):
-         self.X = []
-         for i in input[0]:
-             x = i.detach()
-             x.requires_grad = True
-             self.X.append(x)
-     else:
-         self.X = input[0].detach()
-         self.X.requires_grad = True
-
-     self.Y = output
-
-
- def backward_hook(self, grad_input, grad_output):
-     self.grad_input = grad_input
-     self.grad_output = grad_output
-
-
- class RelProp(nn.Module):
-     def __init__(self):
-         super(RelProp, self).__init__()
-         # if not self.training:
-         self.register_forward_hook(forward_hook)
-
-     def gradprop(self, Z, X, S):
-         C = torch.autograd.grad(Z, X, S, retain_graph=True)
-         return C
-
-     def relprop(self, R, alpha):
-         return R
-
- class RelPropSimple(RelProp):
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         if torch.is_tensor(self.X) == False:
-             outputs = []
-             outputs.append(self.X[0] * C[0])
-             outputs.append(self.X[1] * C[1])
-         else:
-             outputs = self.X * (C[0])
-         return outputs
-
- class AddEye(RelPropSimple):
-     # input of shape B, C, seq_len, seq_len
-     def forward(self, input):
-         return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
-
- class ReLU(nn.ReLU, RelProp):
-     pass
-
- class GELU(nn.GELU, RelProp):
-     pass
-
- class Softmax(nn.Softmax, RelProp):
-     pass
-
- class LayerNorm(nn.LayerNorm, RelProp):
-     pass
-
- class Dropout(nn.Dropout, RelProp):
-     pass
-
-
- class MaxPool2d(nn.MaxPool2d, RelPropSimple):
-     pass
-
- class LayerNorm(nn.LayerNorm, RelProp):
-     pass
-
- class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
-     pass
-
-
- class AvgPool2d(nn.AvgPool2d, RelPropSimple):
-     pass
-
-
- class Add(RelPropSimple):
-     def forward(self, inputs):
-         return torch.add(*inputs)
-
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         a = self.X[0] * C[0]
-         b = self.X[1] * C[1]
-
-         a_sum = a.sum()
-         b_sum = b.sum()
-
-         a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
-         b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
-
-         a = a * safe_divide(a_fact, a.sum())
-         b = b * safe_divide(b_fact, b.sum())
-
-         outputs = [a, b]
-
-         return outputs
-
- class einsum(RelPropSimple):
-     def __init__(self, equation):
-         super().__init__()
-         self.equation = equation
-     def forward(self, *operands):
-         return torch.einsum(self.equation, *operands)
-
- class IndexSelect(RelProp):
-     def forward(self, inputs, dim, indices):
-         self.__setattr__('dim', dim)
-         self.__setattr__('indices', indices)
-
-         return torch.index_select(inputs, dim, indices)
-
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X, self.dim, self.indices)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         if torch.is_tensor(self.X) == False:
-             outputs = []
-             outputs.append(self.X[0] * C[0])
-             outputs.append(self.X[1] * C[1])
-         else:
-             outputs = self.X * (C[0])
-         return outputs
-
-
-
- class Clone(RelProp):
-     def forward(self, input, num):
-         self.__setattr__('num', num)
-         outputs = []
-         for _ in range(num):
-             outputs.append(input)
-
-         return outputs
-
-     def relprop(self, R, alpha):
-         Z = []
-         for _ in range(self.num):
-             Z.append(self.X)
-         S = [safe_divide(r, z) for r, z in zip(R, Z)]
-         C = self.gradprop(Z, self.X, S)[0]
-
-         R = self.X * C
-
-         return R
-
- class Cat(RelProp):
-     def forward(self, inputs, dim):
-         self.__setattr__('dim', dim)
-         return torch.cat(inputs, dim)
-
-     def relprop(self, R, alpha):
-         Z = self.forward(self.X, self.dim)
-         S = safe_divide(R, Z)
-         C = self.gradprop(Z, self.X, S)
-
-         outputs = []
-         for x, c in zip(self.X, C):
-             outputs.append(x * c)
-
-         return outputs
-
-
- class Sequential(nn.Sequential):
-     def relprop(self, R, alpha):
-         for m in reversed(self._modules.values()):
-             R = m.relprop(R, alpha)
-         return R
-
- class BatchNorm2d(nn.BatchNorm2d, RelProp):
-     def relprop(self, R, alpha):
-         X = self.X
-         beta = 1 - alpha
-         weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
-             (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
-         Z = X * weight + 1e-9
-         S = R / Z
-         Ca = S * weight
-         R = self.X * (Ca)
-         return R
-
-
- class Linear(nn.Linear, RelProp):
-     def relprop(self, R, alpha):
-         beta = alpha - 1
-         pw = torch.clamp(self.weight, min=0)
-         nw = torch.clamp(self.weight, max=0)
-         px = torch.clamp(self.X, min=0)
-         nx = torch.clamp(self.X, max=0)
-
-         def f(w1, w2, x1, x2):
-             Z1 = F.linear(x1, w1)
-             Z2 = F.linear(x2, w2)
-             S1 = safe_divide(R, Z1 + Z2)
-             S2 = safe_divide(R, Z1 + Z2)
-             C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
-             C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
-
-             return C1 + C2
-
-         activator_relevances = f(pw, nw, px, nx)
-         inhibitor_relevances = f(nw, pw, px, nx)
-
-         R = alpha * activator_relevances - beta * inhibitor_relevances
-
-         return R
-
-
- class Conv2d(nn.Conv2d, RelProp):
-     def gradprop2(self, DY, weight):
-         Z = self.forward(self.X)
-
-         output_padding = self.X.size()[2] - (
-             (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
-
-         return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
-
-     def relprop(self, R, alpha):
-         if self.X.shape[1] == 3:
-             pw = torch.clamp(self.weight, min=0)
-             nw = torch.clamp(self.weight, max=0)
-             X = self.X
-             L = self.X * 0 + \
-                 torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
-                           keepdim=True)[0]
-             H = self.X * 0 + \
-                 torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
-                           keepdim=True)[0]
-             Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
-                  torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
-                  torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
-
-             S = R / Za
-             C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
-             R = C
-         else:
-             beta = alpha - 1
-             pw = torch.clamp(self.weight, min=0)
-             nw = torch.clamp(self.weight, max=0)
-             px = torch.clamp(self.X, min=0)
-             nx = torch.clamp(self.X, max=0)
-
-             def f(w1, w2, x1, x2):
-                 Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
-                 Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
-                 S1 = safe_divide(R, Z1)
-                 S2 = safe_divide(R, Z2)
-                 C1 = x1 * self.gradprop(Z1, x1, S1)[0]
-                 C2 = x2 * self.gradprop(Z2, x2, S2)[0]
-                 return C1 + C2
-
-             activator_relevances = f(pw, nw, px, nx)
-             inhibitor_relevances = f(nw, pw, px, nx)
-
-             R = alpha * activator_relevances - beta * inhibitor_relevances
-         return R
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/pertubation_eval_from_hdf5.py DELETED
@@ -1,232 +0,0 @@
- import torch
- import os
- from tqdm import tqdm
- import numpy as np
- import argparse
-
- # Import saliency methods and models
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import Baselines
- from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_new import vit_base_patch16_224
- # from models.vgg import vgg19
- import glob
-
- from dataset.expl_hdf5 import ImagenetResults
-
-
- def normalize(tensor,
-               mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
-     dtype = tensor.dtype
-     mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
-     std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
-     tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
-     return tensor
-
-
- def eval(args):
-     num_samples = 0
-     num_correct_model = np.zeros((len(imagenet_ds,)))
-     dissimilarity_model = np.zeros((len(imagenet_ds,)))
-     model_index = 0
-
-     if args.scale == 'per':
-         base_size = 224 * 224
-         perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
-     elif args.scale == '100':
-         base_size = 100
-         perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
-     else:
-         raise Exception('scale not valid')
-
-     num_correct_pertub = np.zeros((9, len(imagenet_ds)))
-     dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
-     logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
-     prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
-     perturb_index = 0
-
-     for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
-         # Update the number of samples
-         num_samples += len(data)
-
-         data = data.to(device)
-         vis = vis.to(device)
-         target = target.to(device)
-         norm_data = normalize(data.clone())
-
-         # Compute model accuracy
-         pred = model(norm_data)
-         pred_probabilities = torch.softmax(pred, dim=1)
-         pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
-         pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
-         pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
-         tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
-         num_correct_model[model_index:model_index+len(tgt_pred)] = tgt_pred
-
-         probs = torch.softmax(pred, dim=1)
-         target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
-         second_probs = probs.data.topk(2, dim=1)[0][:, 1]
-         temp = torch.log(target_probs / second_probs).data.cpu().numpy()
-         dissimilarity_model[model_index:model_index+len(temp)] = temp
-
-         if args.wrong:
-             wid = np.argwhere(tgt_pred == 0).flatten()
-             if len(wid) == 0:
-                 continue
-             wid = torch.from_numpy(wid).to(vis.device)
-             vis = vis.index_select(0, wid)
-             data = data.index_select(0, wid)
-             target = target.index_select(0, wid)
-
-         # Save original shape
-         org_shape = data.shape
-
-         if args.neg:
-             vis = -vis
-
-         vis = vis.reshape(org_shape[0], -1)
-
-         for i in range(len(perturbation_steps)):
-             _data = data.clone()
-
-             _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
-             idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
-             _data = _data.reshape(org_shape[0], org_shape[1], -1)
-             _data = _data.scatter_(-1, idx, 0)
-             _data = _data.reshape(*org_shape)
-
-             _norm_data = normalize(_data)
-
-             out = model(_norm_data)
-
-             pred_probabilities = torch.softmax(out, dim=1)
-             pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
-             diff = (pred_prob - pred_org_prob).data.cpu().numpy()
-             prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
-
-             pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
-             diff = (pred_logit - pred_org_logit).data.cpu().numpy()
-             logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
-
-             target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
-             temp = (target == target_class).type(target.type()).data.cpu().numpy()
-             num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp
-
-             probs_pertub = torch.softmax(out, dim=1)
-             target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
-             second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
-             temp = torch.log(target_probs / second_probs).data.cpu().numpy()
-             dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp
-
-         model_index += len(target)
-         perturb_index += len(target)
-
-     np.save(os.path.join(args.experiment_dir, 'model_hits.npy'), num_correct_model)
-     np.save(os.path.join(args.experiment_dir, 'model_dissimilarities.npy'), dissimilarity_model)
-     np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'), num_correct_pertub[:, :perturb_index])
-     np.save(os.path.join(args.experiment_dir, 'perturbations_dissimilarities.npy'), dissimilarity_pertub[:, :perturb_index])
-     np.save(os.path.join(args.experiment_dir, 'perturbations_logit_diff.npy'), logit_diff_pertub[:, :perturb_index])
-     np.save(os.path.join(args.experiment_dir, 'perturbations_prob_diff.npy'), prob_diff_pertub[:, :perturb_index])
-
-     print(np.mean(num_correct_model), np.std(num_correct_model))
-     print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
-     print(perturbation_steps)
-     print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
-     print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='Train a segmentation')
-     parser.add_argument('--batch-size', type=int,
-                         default=16,
-                         help='')
-     parser.add_argument('--neg', type=bool,
-                         default=True,
-                         help='')
-     parser.add_argument('--value', action='store_true',
-                         default=False,
-                         help='')
-     parser.add_argument('--scale', type=str,
-                         default='per',
-                         choices=['per', '100'],
-                         help='')
-     parser.add_argument('--method', type=str,
-                         default='grad_rollout',
-                         choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer',
-                                  'lrp_second_layer', 'gradcam',
-                                  'attn_last_layer', 'attn_gradcam', 'input_grads'],
-                         help='')
-     parser.add_argument('--vis-class', type=str,
-                         default='top',
-                         choices=['top', 'target', 'index'],
-                         help='')
-     parser.add_argument('--wrong', action='store_true',
-                         default=False,
-                         help='')
-     parser.add_argument('--class-id', type=int,
-                         default=0,
-                         help='')
-     parser.add_argument('--is-ablation', type=bool,
-                         default=False,
-                         help='')
-     args = parser.parse_args()
-
-     torch.multiprocessing.set_start_method('spawn')
-
-     # PATH variables
-     PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
-     dataset = PATH + 'dataset/'
-     os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True)
-     os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)
-
-     exp_name = args.method
-     exp_name += '_neg' if args.neg else '_pos'
-     print(exp_name)
-
-     if args.vis_class == 'index':
-         args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}_{}'.format(exp_name,
-                                                                                        args.vis_class,
-                                                                                        args.class_id))
-     else:
-         ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
-         args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}/{}'.format(exp_name,
-                                                                                        args.vis_class, ablation_fold))
-         # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(exp_name,
-         #                                                                             args.vis_class))
-
-     if args.wrong:
-         args.runs_dir += '_wrong'
-
-     experiments = sorted(glob.glob(os.path.join(args.runs_dir, 'experiment_*')))
-     experiment_id = int(experiments[-1].split('_')[-1]) + 1 if experiments else 0
-     args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id)))
-     os.makedirs(args.experiment_dir, exist_ok=True)
-
-     cuda = torch.cuda.is_available()
-     device = torch.device("cuda" if cuda else "cpu")
-
-     if args.vis_class == 'index':
-         vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
-                                                                              args.vis_class,
-                                                                              args.class_id))
-     else:
-         ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
-         vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
-                                                                              args.vis_class, ablation_fold))
-         # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(args.method,
-         #                                                                   args.vis_class))
-
-     # imagenet_ds = ImagenetResults('visualizations/{}'.format(args.method))
-     imagenet_ds = ImagenetResults(vis_method_dir)
-
-     # Model
-     model = vit_base_patch16_224(pretrained=True).cuda()
-     model.eval()
-
-     save_path = PATH + 'results/'
-
-     sample_loader = torch.utils.data.DataLoader(
-         imagenet_ds,
-         batch_size=args.batch_size,
-         num_workers=2,
-         shuffle=False)
-
-     eval(args)
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__init__.py DELETED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (221 Bytes)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/confusionmatrix.cpython-310.pyc DELETED
Binary file (3.55 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/iou.cpython-310.pyc DELETED
Binary file (3.64 kB)
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/metric.cpython-310.pyc DELETED
Binary file (821 Bytes)