UPstud committed on
Commit
e9c9507
·
verified ·
1 Parent(s): 85ab3f2

Upload 40 files

Files changed (40)
  1. taming/__pycache__/util.cpython-38.pyc +0 -0
  2. taming/data/ade20k.py +124 -0
  3. taming/data/annotated_objects_coco.py +139 -0
  4. taming/data/annotated_objects_dataset.py +218 -0
  5. taming/data/annotated_objects_open_images.py +137 -0
  6. taming/data/base.py +70 -0
  7. taming/data/coco.py +176 -0
  8. taming/data/conditional_builder/objects_bbox.py +60 -0
  9. taming/data/conditional_builder/objects_center_points.py +168 -0
  10. taming/data/conditional_builder/utils.py +105 -0
  11. taming/data/custom.py +38 -0
  12. taming/data/faceshq.py +134 -0
  13. taming/data/helper_types.py +49 -0
  14. taming/data/image_transforms.py +132 -0
  15. taming/data/imagenet.py +558 -0
  16. taming/data/open_images_helper.py +379 -0
  17. taming/data/sflckr.py +91 -0
  18. taming/data/utils.py +169 -0
  19. taming/lr_scheduler.py +34 -0
  20. taming/models/cond_transformer.py +352 -0
  21. taming/models/dummy_cond_stage.py +22 -0
  22. taming/models/vqgan.py +404 -0
  23. taming/modules/__pycache__/util.cpython-38.pyc +0 -0
  24. taming/modules/autoencoder/lpips/vgg.pth +3 -0
  25. taming/modules/diffusionmodules/model.py +776 -0
  26. taming/modules/discriminator/__pycache__/model.cpython-38.pyc +0 -0
  27. taming/modules/discriminator/model.py +67 -0
  28. taming/modules/losses/__init__.py +2 -0
  29. taming/modules/losses/__pycache__/__init__.cpython-38.pyc +0 -0
  30. taming/modules/losses/__pycache__/lpips.cpython-38.pyc +0 -0
  31. taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc +0 -0
  32. taming/modules/losses/lpips.py +123 -0
  33. taming/modules/losses/segmentation.py +22 -0
  34. taming/modules/losses/vqperceptual.py +241 -0
  35. taming/modules/misc/coord.py +31 -0
  36. taming/modules/transformer/mingpt.py +415 -0
  37. taming/modules/transformer/permuter.py +248 -0
  38. taming/modules/util.py +130 -0
  39. taming/modules/vqvae/quantize.py +445 -0
  40. taming/util.py +157 -0
taming/__pycache__/util.cpython-38.pyc ADDED
Binary file (4.12 kB).
 
taming/data/ade20k.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import albumentations
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+
8
+ from taming.data.sflckr import SegmentationBase # for examples included in repo
9
+
10
+
11
+ class Examples(SegmentationBase):
12
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
13
+ super().__init__(data_csv="data/ade20k_examples.txt",
14
+ data_root="data/ade20k_images",
15
+ segmentation_root="data/ade20k_segmentations",
16
+ size=size, random_crop=random_crop,
17
+ interpolation=interpolation,
18
+ n_labels=151, shift_segmentation=False)
19
+
20
+
21
+ # With semantic map and scene label
22
+ class ADE20kBase(Dataset):
23
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
24
+ self.split = self.get_split()
25
+ self.n_labels = 151 # unknown + 150
26
+ self.data_csv = {"train": "data/ade20k_train.txt",
27
+ "validation": "data/ade20k_test.txt"}[self.split]
28
+ self.data_root = "data/ade20k_root"
29
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
30
+ self.scene_categories = f.read().splitlines()
31
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
32
+ with open(self.data_csv, "r") as f:
33
+ self.image_paths = f.read().splitlines()
34
+ self._length = len(self.image_paths)
35
+ self.labels = {
36
+ "relative_file_path_": [l for l in self.image_paths],
37
+ "file_path_": [os.path.join(self.data_root, "images", l)
38
+ for l in self.image_paths],
39
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
40
+ for l in self.image_paths],
41
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
42
+ l.replace(".jpg", ".png"))
43
+ for l in self.image_paths],
44
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
45
+ for l in self.image_paths],
46
+ }
47
+
48
+ size = None if size is not None and size<=0 else size
49
+ self.size = size
50
+ if crop_size is None:
51
+ self.crop_size = size if size is not None else None
52
+ else:
53
+ self.crop_size = crop_size
54
+ if self.size is not None:
55
+ self.interpolation = interpolation
56
+ self.interpolation = {
57
+ "nearest": cv2.INTER_NEAREST,
58
+ "bilinear": cv2.INTER_LINEAR,
59
+ "bicubic": cv2.INTER_CUBIC,
60
+ "area": cv2.INTER_AREA,
61
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
62
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
63
+ interpolation=self.interpolation)
64
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
65
+ interpolation=cv2.INTER_NEAREST)
66
+
67
+ if crop_size is not None:
68
+ self.center_crop = not random_crop
69
+ if self.center_crop:
70
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
71
+ else:
72
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
73
+ self.preprocessor = self.cropper
74
+
75
+ def __len__(self):
76
+ return self._length
77
+
78
+ def __getitem__(self, i):
79
+ example = dict((k, self.labels[k][i]) for k in self.labels)
80
+ image = Image.open(example["file_path_"])
81
+ if not image.mode == "RGB":
82
+ image = image.convert("RGB")
83
+ image = np.array(image).astype(np.uint8)
84
+ if self.size is not None:
85
+ image = self.image_rescaler(image=image)["image"]
86
+ segmentation = Image.open(example["segmentation_path_"])
87
+ segmentation = np.array(segmentation).astype(np.uint8)
88
+ if self.size is not None:
89
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
90
+ if self.size is not None:
91
+ processed = self.preprocessor(image=image, mask=segmentation)
92
+ else:
93
+ processed = {"image": image, "mask": segmentation}
94
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
95
+ segmentation = processed["mask"]
96
+ onehot = np.eye(self.n_labels)[segmentation]
97
+ example["segmentation"] = onehot
98
+ return example
99
+
100
+
101
+ class ADE20kTrain(ADE20kBase):
102
+ # default to random_crop=True
103
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
104
+ super().__init__(config=config, size=size, random_crop=random_crop,
105
+ interpolation=interpolation, crop_size=crop_size)
106
+
107
+ def get_split(self):
108
+ return "train"
109
+
110
+
111
+ class ADE20kValidation(ADE20kBase):
112
+ def get_split(self):
113
+ return "validation"
114
+
115
+
116
+ if __name__ == "__main__":
117
+ dset = ADE20kValidation()
118
+ ex = dset[0]
119
+ for k in ["image", "scene_category", "segmentation"]:
120
+ print(type(ex[k]))
121
+ try:
122
+ print(ex[k].shape)
123
+ except AttributeError:  # non-array entries (e.g. scene_category) have no .shape
124
+ print(ex[k])
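For orientation, a minimal usage sketch of the validation loader, assuming the data/ade20k_* listing files and the data/ade20k_root layout referenced in ADE20kBase.__init__ are in place; the shapes follow from __getitem__ with size=crop_size=256:

# Minimal sketch; requires data/ade20k_test.txt, data/ade20k_root/images,
# data/ade20k_root/annotations and data/ade20k_root/sceneCategories.txt.
from taming.data.ade20k import ADE20kValidation

dset = ADE20kValidation(size=256, crop_size=256)
ex = dset[0]
print(ex["image"].shape)         # (256, 256, 3), float32 scaled to [-1, 1]
print(ex["segmentation"].shape)  # (256, 256, 151), one-hot over unknown + 150 labels
print(ex["scene_category"])      # scene label read from sceneCategories.txt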
taming/data/annotated_objects_coco.py ADDED
@@ -0,0 +1,139 @@
1
+ import json
2
+ from itertools import chain
3
+ from pathlib import Path
4
+ from typing import Iterable, Dict, List, Callable, Any
5
+ from collections import defaultdict
6
+
7
+ from tqdm import tqdm
8
+
9
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
10
+ from taming.data.helper_types import Annotation, ImageDescription, Category
11
+
12
+ COCO_PATH_STRUCTURE = {
13
+ 'train': {
14
+ 'top_level': '',
15
+ 'instances_annotations': 'annotations/instances_train2017.json',
16
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
17
+ 'files': 'train2017'
18
+ },
19
+ 'validation': {
20
+ 'top_level': '',
21
+ 'instances_annotations': 'annotations/instances_val2017.json',
22
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
23
+ 'files': 'val2017'
24
+ }
25
+ }
26
+
27
+
28
+ def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
29
+ return {
30
+ str(img['id']): ImageDescription(
31
+ id=img['id'],
32
+ license=img.get('license'),
33
+ file_name=img['file_name'],
34
+ coco_url=img['coco_url'],
35
+ original_size=(img['width'], img['height']),
36
+ date_captured=img.get('date_captured'),
37
+ flickr_url=img.get('flickr_url')
38
+ )
39
+ for img in description_json
40
+ }
41
+
42
+
43
+ def load_categories(category_json: Iterable) -> Dict[str, Category]:
44
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
45
+ for cat in category_json if cat['name'] != 'other'}
46
+
47
+
48
+ def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
49
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
50
+ annotations = defaultdict(list)
51
+ total = sum(len(a) for a in annotations_json)
52
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
53
+ image_id = str(ann['image_id'])
54
+ if image_id not in image_descriptions:
55
+ raise ValueError(f'image_id [{image_id}] has no image description.')
56
+ category_id = ann['category_id']
57
+ try:
58
+ category_no = category_no_for_id(str(category_id))
59
+ except KeyError:
60
+ continue
61
+
62
+ width, height = image_descriptions[image_id].original_size
63
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
64
+
65
+ annotations[image_id].append(
66
+ Annotation(
67
+ id=ann['id'],
68
+ area=bbox[2]*bbox[3], # use bbox area
69
+ is_group_of=ann['iscrowd'],
70
+ image_id=ann['image_id'],
71
+ bbox=bbox,
72
+ category_id=str(category_id),
73
+ category_no=category_no
74
+ )
75
+ )
76
+ return dict(annotations)
77
+
78
+
79
+ class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
80
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
81
+ """
82
+ @param data_path: is the path to the following folder structure:
83
+ coco/
84
+ ├── annotations
85
+ │ ├── instances_train2017.json
86
+ │ ├── instances_val2017.json
87
+ │ ├── stuff_train2017.json
88
+ │ └── stuff_val2017.json
89
+ ├── train2017
90
+ │ ├── 000000000009.jpg
91
+ │ ├── 000000000025.jpg
92
+ │ └── ...
93
+ ├── val2017
94
+ │ ├── 000000000139.jpg
95
+ │ ├── 000000000285.jpg
96
+ │ └── ...
97
+ @param: split: one of 'train' or 'validation'
98
+ @param: desired image size (returns square images)
99
+ """
100
+ super().__init__(**kwargs)
101
+ self.use_things = use_things
102
+ self.use_stuff = use_stuff
103
+
104
+ with open(self.paths['instances_annotations']) as f:
105
+ inst_data_json = json.load(f)
106
+ with open(self.paths['stuff_annotations']) as f:
107
+ stuff_data_json = json.load(f)
108
+
109
+ category_jsons = []
110
+ annotation_jsons = []
111
+ if self.use_things:
112
+ category_jsons.append(inst_data_json['categories'])
113
+ annotation_jsons.append(inst_data_json['annotations'])
114
+ if self.use_stuff:
115
+ category_jsons.append(stuff_data_json['categories'])
116
+ annotation_jsons.append(stuff_data_json['annotations'])
117
+
118
+ self.categories = load_categories(chain(*category_jsons))
119
+ self.filter_categories()
120
+ self.setup_category_id_and_number()
121
+
122
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
123
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
124
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
125
+ self.min_objects_per_image, self.max_objects_per_image)
126
+ self.image_ids = list(self.annotations.keys())
127
+ self.clean_up_annotations_and_image_descriptions()
128
+
129
+ def get_path_structure(self) -> Dict[str, str]:
130
+ if self.split not in COCO_PATH_STRUCTURE:
131
+ raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
132
+ return COCO_PATH_STRUCTURE[self.split]
133
+
134
+ def get_image_path(self, image_id: str) -> Path:
135
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
136
+
137
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
138
+ # noinspection PyProtectedMember
139
+ return self.image_descriptions[image_id]._asdict()
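For orientation, a hedged instantiation sketch; everything besides use_things/use_stuff is forwarded to AnnotatedObjectsDataset (next file), and the concrete values below are illustrative rather than taken from a config in this upload:

# Hedged sketch; assumes a local COCO copy laid out as in the docstring above.
from taming.data.annotated_objects_coco import AnnotatedObjectsCoco

dset = AnnotatedObjectsCoco(
    use_things=True, use_stuff=True,
    data_path="data/coco",           # folder containing annotations/, train2017/, val2017/
    split="train",
    keys=["image", "objects_bbox"],  # restrict what __getitem__ returns
    target_image_size=256,
    min_object_area=0.0001,          # relative bbox area
    min_objects_per_image=2,
    max_objects_per_image=30,
    crop_method="random-1d",
    random_flip=True,
    no_tokens=1024,
    use_group_parameter=True,
    encode_crop=True,
)
sample = dset[0]
print(sample["image"].shape)   # (256, 256, 3) tensor in [-1, 1]
print(sample["objects_bbox"])  # flat LongTensor of conditioning tokens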
taming/data/annotated_objects_dataset.py ADDED
@@ -0,0 +1,218 @@
1
+ from pathlib import Path
2
+ from typing import Optional, List, Callable, Dict, Any, Union
3
+ import warnings
4
+
5
+ import PIL.Image as pil_image
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+
10
+ from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
11
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
12
+ from taming.data.conditional_builder.utils import load_object_from_string
13
+ from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
14
+ from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
15
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
16
+
17
+
18
+ class AnnotatedObjectsDataset(Dataset):
19
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
20
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
21
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
22
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
23
+ no_object_classes: Optional[int] = None):
24
+ self.data_path = data_path
25
+ self.split = split
26
+ self.keys = keys
27
+ self.target_image_size = target_image_size
28
+ self.min_object_area = min_object_area
29
+ self.min_objects_per_image = min_objects_per_image
30
+ self.max_objects_per_image = max_objects_per_image
31
+ self.crop_method = crop_method
32
+ self.random_flip = random_flip
33
+ self.no_tokens = no_tokens
34
+ self.use_group_parameter = use_group_parameter
35
+ self.encode_crop = encode_crop
36
+
37
+ self.annotations = None
38
+ self.image_descriptions = None
39
+ self.categories = None
40
+ self.category_ids = None
41
+ self.category_number = None
42
+ self.image_ids = None
43
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
44
+ self.paths = self.build_paths(self.data_path)
45
+ self._conditional_builders = None
46
+ self.category_allow_list = None
47
+ if category_allow_list_target:
48
+ allow_list = load_object_from_string(category_allow_list_target)
49
+ self.category_allow_list = {name for name, _ in allow_list}
50
+ self.category_mapping = {}
51
+ if category_mapping_target:
52
+ self.category_mapping = load_object_from_string(category_mapping_target)
53
+ self.no_object_classes = no_object_classes
54
+
55
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
56
+ top_level = Path(top_level)
57
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
58
+ for path in sub_paths.values():
59
+ if not path.exists():
60
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
61
+ return sub_paths
62
+
63
+ @staticmethod
64
+ def load_image_from_disk(path: Path) -> Image:
65
+ return pil_image.open(path).convert('RGB')
66
+
67
+ @staticmethod
68
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
69
+ transform_functions = []
70
+ if crop_method == 'none':
71
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
72
+ elif crop_method == 'center':
73
+ transform_functions.extend([
74
+ transforms.Resize(target_image_size),
75
+ CenterCropReturnCoordinates(target_image_size)
76
+ ])
77
+ elif crop_method == 'random-1d':
78
+ transform_functions.extend([
79
+ transforms.Resize(target_image_size),
80
+ RandomCrop1dReturnCoordinates(target_image_size)
81
+ ])
82
+ elif crop_method == 'random-2d':
83
+ transform_functions.extend([
84
+ Random2dCropReturnCoordinates(target_image_size),
85
+ transforms.Resize(target_image_size)
86
+ ])
87
+ elif crop_method is None:
88
+ return None
89
+ else:
90
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
91
+ if random_flip:
92
+ transform_functions.append(RandomHorizontalFlipReturn())
93
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
94
+ return transform_functions
95
+
96
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
97
+ crop_bbox = None
98
+ flipped = None
99
+ for t in self.transform_functions:
100
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
101
+ crop_bbox, x = t(x)
102
+ elif isinstance(t, RandomHorizontalFlipReturn):
103
+ flipped, x = t(x)
104
+ else:
105
+ x = t(x)
106
+ return crop_bbox, flipped, x
107
+
108
+ @property
109
+ def no_classes(self) -> int:
110
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
111
+
112
+ @property
113
+ def conditional_builders(self) -> Dict[str, ObjectsCenterPointsConditionalBuilder]:
114
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
115
+ if self._conditional_builders is None:
116
+ self._conditional_builders = {
117
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
118
+ self.no_classes,
119
+ self.max_objects_per_image,
120
+ self.no_tokens,
121
+ self.encode_crop,
122
+ self.use_group_parameter,
123
+ getattr(self, 'use_additional_parameters', False)
124
+ ),
125
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
126
+ self.no_classes,
127
+ self.max_objects_per_image,
128
+ self.no_tokens,
129
+ self.encode_crop,
130
+ self.use_group_parameter,
131
+ getattr(self, 'use_additional_parameters', False)
132
+ )
133
+ }
134
+ return self._conditional_builders
135
+
136
+ def filter_categories(self) -> None:
137
+ if self.category_allow_list:
138
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
139
+ if self.category_mapping:
140
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
141
+
142
+ def setup_category_id_and_number(self) -> None:
143
+ self.category_ids = list(self.categories.keys())
144
+ self.category_ids.sort()
145
+ if '/m/01s55n' in self.category_ids:
146
+ self.category_ids.remove('/m/01s55n')
147
+ self.category_ids.append('/m/01s55n')
148
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
149
+ if self.category_allow_list is not None and self.category_mapping is None \
150
+ and len(self.category_ids) != len(self.category_allow_list):
151
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
152
+ 'Make sure all names in category_allow_list exist.')
153
+
154
+ def clean_up_annotations_and_image_descriptions(self) -> None:
155
+ image_id_set = set(self.image_ids)
156
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
157
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
158
+
159
+ @staticmethod
160
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
161
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
162
+ filtered = {}
163
+ for image_id, annotations in all_annotations.items():
164
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
165
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
166
+ filtered[image_id] = annotations_with_min_area
167
+ return filtered
168
+
169
+ def __len__(self):
170
+ return len(self.image_ids)
171
+
172
+ def __getitem__(self, n: int) -> Dict[str, Any]:
173
+ image_id = self.get_image_id(n)
174
+ sample = self.get_image_description(image_id)
175
+ sample['annotations'] = self.get_annotation(image_id)
176
+
177
+ if 'image' in self.keys:
178
+ sample['image_path'] = str(self.get_image_path(image_id))
179
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
180
+ sample['image'] = convert_pil_to_tensor(sample['image'])
181
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
182
+ sample['image'] = sample['image'].permute(1, 2, 0)
183
+
184
+ for conditional, builder in self.conditional_builders.items():
185
+ if conditional in self.keys:
186
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
187
+
188
+ if self.keys:
189
+ # only return specified keys
190
+ sample = {key: sample[key] for key in self.keys}
191
+ return sample
192
+
193
+ def get_image_id(self, no: int) -> str:
194
+ return self.image_ids[no]
195
+
196
+ def get_annotation(self, image_id: str) -> str:
197
+ return self.annotations[image_id]
198
+
199
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
200
+ return self.categories[category_id].name
201
+
202
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
203
+ return self.categories[self.get_category_id(category_no)].name
204
+
205
+ def get_category_number(self, category_id: str) -> int:
206
+ return self.category_number[category_id]
207
+
208
+ def get_category_id(self, category_no: int) -> str:
209
+ return self.category_ids[category_no]
210
+
211
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
212
+ raise NotImplementedError()
213
+
214
+ def get_path_structure(self):
215
+ raise NotImplementedError
216
+
217
+ def get_image_path(self, image_id: str) -> Path:
218
+ raise NotImplementedError
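The class above is abstract: concrete datasets (the COCO and Open Images classes in this upload) implement the three hooks at the end and then populate categories, annotations and image_ids. A schematic sketch with hypothetical names and paths, only to show which pieces a subclass supplies:

# Hypothetical subclass; the sub-paths below are placeholders and must exist
# under data_path, since build_paths() checks them in __init__.
from pathlib import Path
from typing import Any, Dict

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset


class MyAnnotatedObjects(AnnotatedObjectsDataset):
    def get_path_structure(self) -> Dict[str, str]:
        # relative sub-paths that build_paths() joins onto data_path
        return {"files": "images", "annotations": "annotations.json"}

    def get_image_path(self, image_id: str) -> Path:
        return self.paths["files"].joinpath(f"{image_id}.jpg")

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        return {"file_name": f"{image_id}.jpg"}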
taming/data/annotated_objects_open_images.py ADDED
@@ -0,0 +1,137 @@
1
+ from collections import defaultdict
2
+ from csv import DictReader, reader as TupleReader
3
+ from pathlib import Path
4
+ from typing import Dict, List, Any
5
+ import warnings
6
+
7
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
8
+ from taming.data.helper_types import Annotation, Category
9
+ from tqdm import tqdm
10
+
11
+ OPEN_IMAGES_STRUCTURE = {
12
+ 'train': {
13
+ 'top_level': '',
14
+ 'class_descriptions': 'class-descriptions-boxable.csv',
15
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
16
+ 'file_list': 'train-images-boxable.csv',
17
+ 'files': 'train'
18
+ },
19
+ 'validation': {
20
+ 'top_level': '',
21
+ 'class_descriptions': 'class-descriptions-boxable.csv',
22
+ 'annotations': 'validation-annotations-bbox.csv',
23
+ 'file_list': 'validation-images.csv',
24
+ 'files': 'validation'
25
+ },
26
+ 'test': {
27
+ 'top_level': '',
28
+ 'class_descriptions': 'class-descriptions-boxable.csv',
29
+ 'annotations': 'test-annotations-bbox.csv',
30
+ 'file_list': 'test-images.csv',
31
+ 'files': 'test'
32
+ }
33
+ }
34
+
35
+
36
+ def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
37
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
38
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
39
+ with open(descriptor_path) as file:
40
+ reader = DictReader(file)
41
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
42
+ width = float(row['XMax']) - float(row['XMin'])
43
+ height = float(row['YMax']) - float(row['YMin'])
44
+ area = width * height
45
+ category_id = row['LabelName']
46
+ if category_id in category_mapping:
47
+ category_id = category_mapping[category_id]
48
+ if area >= min_object_area and category_id in category_no_for_id:
49
+ annotations[row['ImageID']].append(
50
+ Annotation(
51
+ id=i,
52
+ image_id=row['ImageID'],
53
+ source=row['Source'],
54
+ category_id=category_id,
55
+ category_no=category_no_for_id[category_id],
56
+ confidence=float(row['Confidence']),
57
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
58
+ area=area,
59
+ is_occluded=bool(int(row['IsOccluded'])),
60
+ is_truncated=bool(int(row['IsTruncated'])),
61
+ is_group_of=bool(int(row['IsGroupOf'])),
62
+ is_depiction=bool(int(row['IsDepiction'])),
63
+ is_inside=bool(int(row['IsInside']))
64
+ )
65
+ )
66
+ if 'train' in str(descriptor_path) and i < 14000000:
67
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
68
+ return dict(annotations)
69
+
70
+
71
+ def load_image_ids(csv_path: Path) -> List[str]:
72
+ with open(csv_path) as file:
73
+ reader = DictReader(file)
74
+ return [row['image_name'] for row in reader]
75
+
76
+
77
+ def load_categories(csv_path: Path) -> Dict[str, Category]:
78
+ with open(csv_path) as file:
79
+ reader = TupleReader(file)
80
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
81
+
82
+
83
+ class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
84
+ def __init__(self, use_additional_parameters: bool, **kwargs):
85
+ """
86
+ @param data_path: is the path to the following folder structure:
87
+ open_images/
88
+ β”‚ oidv6-train-annotations-bbox.csv
89
+ β”œβ”€β”€ class-descriptions-boxable.csv
90
+ β”œβ”€β”€ oidv6-train-annotations-bbox.csv
91
+ β”œβ”€β”€ test
92
+ β”‚ β”œβ”€β”€ 000026e7ee790996.jpg
93
+ β”‚ β”œβ”€β”€ 000062a39995e348.jpg
94
+ β”‚ └── ...
95
+ β”œβ”€β”€ test-annotations-bbox.csv
96
+ β”œβ”€β”€ test-images.csv
97
+ β”œβ”€β”€ train
98
+ β”‚ β”œβ”€β”€ 000002b66c9c498e.jpg
99
+ β”‚ β”œβ”€β”€ 000002b97e5471a0.jpg
100
+ β”‚ └── ...
101
+ β”œβ”€β”€ train-images-boxable.csv
102
+ β”œβ”€β”€ validation
103
+ β”‚ β”œβ”€β”€ 0001eeaf4aed83f9.jpg
104
+ β”‚ β”œβ”€β”€ 0004886b7d043cfd.jpg
105
+ β”‚ └── ...
106
+ β”œβ”€β”€ validation-annotations-bbox.csv
107
+ └── validation-images.csv
108
+ @param: split: one of 'train', 'validation' or 'test'
109
+ @param: desired image size (returns square images)
110
+ """
111
+
112
+ super().__init__(**kwargs)
113
+ self.use_additional_parameters = use_additional_parameters
114
+
115
+ self.categories = load_categories(self.paths['class_descriptions'])
116
+ self.filter_categories()
117
+ self.setup_category_id_and_number()
118
+
119
+ self.image_descriptions = {}
120
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
121
+ self.category_number)
122
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
123
+ self.max_objects_per_image)
124
+ self.image_ids = list(self.annotations.keys())
125
+ self.clean_up_annotations_and_image_descriptions()
126
+
127
+ def get_path_structure(self) -> Dict[str, str]:
128
+ if self.split not in OPEN_IMAGES_STRUCTURE:
129
+ raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
130
+ return OPEN_IMAGES_STRUCTURE[self.split]
131
+
132
+ def get_image_path(self, image_id: str) -> Path:
133
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
134
+
135
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
136
+ image_path = self.get_image_path(image_id)
137
+ return {'file_path': str(image_path), 'file_name': image_path.name}
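A small, self-contained illustration of the zero-padded file-name lookup in get_image_path above; the shortened image id and the data path are hypothetical:

# Open Images ids are zero-padded to 16 characters before the .jpg lookup.
from pathlib import Path

image_id = "2b97e5471a0"            # hypothetical shortened id
file_name = f"{image_id:0>16}.jpg"  # -> 000002b97e5471a0.jpg, cf. the docstring tree
print(Path("data/open_images/train").joinpath(file_name))  # placeholder data path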
taming/data/base.py ADDED
@@ -0,0 +1,70 @@
1
+ import bisect
2
+ import numpy as np
3
+ import albumentations
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset, ConcatDataset
6
+
7
+
8
+ class ConcatDatasetWithIndex(ConcatDataset):
9
+ """Modified from original pytorch code to return dataset idx"""
10
+ def __getitem__(self, idx):
11
+ if idx < 0:
12
+ if -idx > len(self):
13
+ raise ValueError("absolute value of index should not exceed dataset length")
14
+ idx = len(self) + idx
15
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16
+ if dataset_idx == 0:
17
+ sample_idx = idx
18
+ else:
19
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
21
+
22
+
23
+ class ImagePaths(Dataset):
24
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
25
+ self.size = size
26
+ self.random_crop = random_crop
27
+
28
+ self.labels = dict() if labels is None else labels
29
+ self.labels["file_path_"] = paths
30
+ self._length = len(paths)
31
+
32
+ if self.size is not None and self.size > 0:
33
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34
+ if not self.random_crop:
35
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36
+ else:
37
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39
+ else:
40
+ self.preprocessor = lambda **kwargs: kwargs
41
+
42
+ def __len__(self):
43
+ return self._length
44
+
45
+ def preprocess_image(self, image_path):
46
+ image = Image.open(image_path)
47
+ if not image.mode == "RGB":
48
+ image = image.convert("RGB")
49
+ image = np.array(image).astype(np.uint8)
50
+ image = self.preprocessor(image=image)["image"]
51
+ image = (image/127.5 - 1.0).astype(np.float32)
52
+ return image
53
+
54
+ def __getitem__(self, i):
55
+ example = dict()
56
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57
+ for k in self.labels:
58
+ example[k] = self.labels[k][i]
59
+ return example
60
+
61
+
62
+ class NumpyPaths(ImagePaths):
63
+ def preprocess_image(self, image_path):
64
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65
+ image = np.transpose(image, (1,2,0))
66
+ image = Image.fromarray(image, mode="RGB")
67
+ image = np.array(image).astype(np.uint8)
68
+ image = self.preprocessor(image=image)["image"]
69
+ image = (image/127.5 - 1.0).astype(np.float32)
70
+ return image
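A hedged usage sketch of ImagePaths and ConcatDatasetWithIndex; the image paths are placeholders for any local RGB images:

# Placeholder paths; any local RGB images work.
from taming.data.base import ImagePaths, ConcatDatasetWithIndex

cats = ImagePaths(paths=["cat_0.jpg", "cat_1.jpg"], size=256, random_crop=False)
dogs = ImagePaths(paths=["dog_0.jpg"], size=256, random_crop=True)

combined = ConcatDatasetWithIndex([cats, dogs])
example, dataset_idx = combined[2]          # third item overall -> the dog image
print(example["image"].shape, dataset_idx)  # (256, 256, 3) float32 in [-1, 1], 1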
taming/data/coco.py ADDED
@@ -0,0 +1,176 @@
1
+ import os
2
+ import json
3
+ import albumentations
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from torch.utils.data import Dataset
8
+
9
+ from taming.data.sflckr import SegmentationBase # for examples included in repo
10
+
11
+
12
+ class Examples(SegmentationBase):
13
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
14
+ super().__init__(data_csv="data/coco_examples.txt",
15
+ data_root="data/coco_images",
16
+ segmentation_root="data/coco_segmentations",
17
+ size=size, random_crop=random_crop,
18
+ interpolation=interpolation,
19
+ n_labels=183, shift_segmentation=True)
20
+
21
+
22
+ class CocoBase(Dataset):
23
+ """needed for (image, caption, segmentation) pairs"""
24
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
25
+ crop_size=None, force_no_crop=False, given_files=None):
26
+ self.split = self.get_split()
27
+ self.size = size
28
+ if crop_size is None:
29
+ self.crop_size = size
30
+ else:
31
+ self.crop_size = crop_size
32
+
33
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
34
+ self.stuffthing = use_stuffthing # include thing in segmentation
35
+ if self.onehot and not self.stuffthing:
36
+ raise NotImplemented("One hot mode is only supported for the "
37
+ "stuffthings version because labels are stored "
38
+ "a bit different.")
39
+
40
+ data_json = datajson
41
+ with open(data_json) as json_file:
42
+ self.json_data = json.load(json_file)
43
+ self.img_id_to_captions = dict()
44
+ self.img_id_to_filepath = dict()
45
+ self.img_id_to_segmentation_filepath = dict()
46
+
47
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
48
+ "captions_val2017.json"]
49
+ if self.stuffthing:
50
+ self.segmentation_prefix = (
51
+ "data/cocostuffthings/val2017" if
52
+ data_json.endswith("captions_val2017.json") else
53
+ "data/cocostuffthings/train2017")
54
+ else:
55
+ self.segmentation_prefix = (
56
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
57
+ data_json.endswith("captions_val2017.json") else
58
+ "data/coco/annotations/stuff_train2017_pixelmaps")
59
+
60
+ imagedirs = self.json_data["images"]
61
+ self.labels = {"image_ids": list()}
62
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
63
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
64
+ self.img_id_to_captions[imgdir["id"]] = list()
65
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
66
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
67
+ self.segmentation_prefix, pngfilename)
68
+ if given_files is not None:
69
+ if pngfilename in given_files:
70
+ self.labels["image_ids"].append(imgdir["id"])
71
+ else:
72
+ self.labels["image_ids"].append(imgdir["id"])
73
+
74
+ capdirs = self.json_data["annotations"]
75
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
76
+ # there are in average 5 captions per image
77
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
78
+
79
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
80
+ if self.split=="validation":
81
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
82
+ else:
83
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
84
+ self.preprocessor = albumentations.Compose(
85
+ [self.rescaler, self.cropper],
86
+ additional_targets={"segmentation": "image"})
87
+ if force_no_crop:
88
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
89
+ self.preprocessor = albumentations.Compose(
90
+ [self.rescaler],
91
+ additional_targets={"segmentation": "image"})
92
+
93
+ def __len__(self):
94
+ return len(self.labels["image_ids"])
95
+
96
+ def preprocess_image(self, image_path, segmentation_path):
97
+ image = Image.open(image_path)
98
+ if not image.mode == "RGB":
99
+ image = image.convert("RGB")
100
+ image = np.array(image).astype(np.uint8)
101
+
102
+ segmentation = Image.open(segmentation_path)
103
+ if not self.onehot and not segmentation.mode == "RGB":
104
+ segmentation = segmentation.convert("RGB")
105
+ segmentation = np.array(segmentation).astype(np.uint8)
106
+ if self.onehot:
107
+ assert self.stuffthing
108
+ # stored in caffe format: unlabeled==255. stuff and thing from
109
+ # 0-181. to be compatible with the labels in
110
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
111
+ # we shift stuffthing one to the right and put unlabeled in zero
112
+ # as long as segmentation is uint8 shifting to right handles the
113
+ # latter too
114
+ assert segmentation.dtype == np.uint8
115
+ segmentation = segmentation + 1
116
+
117
+ processed = self.preprocessor(image=image, segmentation=segmentation)
118
+ image, segmentation = processed["image"], processed["segmentation"]
119
+ image = (image / 127.5 - 1.0).astype(np.float32)
120
+
121
+ if self.onehot:
122
+ assert segmentation.dtype == np.uint8
123
+ # make it one hot
124
+ n_labels = 183
125
+ flatseg = np.ravel(segmentation)
126
+ onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
127
+ onehot[np.arange(flatseg.size), flatseg] = True
128
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
129
+ segmentation = onehot
130
+ else:
131
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
132
+ return image, segmentation
133
+
134
+ def __getitem__(self, i):
135
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
136
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
137
+ image, segmentation = self.preprocess_image(img_path, seg_path)
138
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
139
+ # randomly draw one of all available captions per image
140
+ caption = captions[np.random.randint(0, len(captions))]
141
+ example = {"image": image,
142
+ "caption": [str(caption[0])],
143
+ "segmentation": segmentation,
144
+ "img_path": img_path,
145
+ "seg_path": seg_path,
146
+ "filename_": img_path.split(os.sep)[-1]
147
+ }
148
+ return example
149
+
150
+
151
+ class CocoImagesAndCaptionsTrain(CocoBase):
152
+ """returns a pair of (image, caption)"""
153
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
154
+ super().__init__(size=size,
155
+ dataroot="data/coco/train2017",
156
+ datajson="data/coco/annotations/captions_train2017.json",
157
+ onehot_segmentation=onehot_segmentation,
158
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
159
+
160
+ def get_split(self):
161
+ return "train"
162
+
163
+
164
+ class CocoImagesAndCaptionsValidation(CocoBase):
165
+ """returns a pair of (image, caption)"""
166
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
167
+ given_files=None):
168
+ super().__init__(size=size,
169
+ dataroot="data/coco/val2017",
170
+ datajson="data/coco/annotations/captions_val2017.json",
171
+ onehot_segmentation=onehot_segmentation,
172
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
173
+ given_files=given_files)
174
+
175
+ def get_split(self):
176
+ return "validation"
taming/data/conditional_builder/objects_bbox.py ADDED
@@ -0,0 +1,60 @@
1
+ from itertools import cycle
2
+ from typing import List, Tuple, Callable, Optional
3
+
4
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5
+ from more_itertools.recipes import grouper
6
+ from taming.data.image_transforms import convert_pil_to_tensor
7
+ from torch import LongTensor, Tensor
8
+
9
+ from taming.data.helper_types import BoundingBox, Annotation
10
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11
+ from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12
+ pad_list, get_plot_font_size, absolute_bbox
13
+
14
+
15
+ class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16
+ @property
17
+ def object_descriptor_length(self) -> int:
18
+ return 3
19
+
20
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21
+ object_triples = [
22
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23
+ for ann in annotations
24
+ ]
25
+ empty_triple = (self.none, self.none, self.none)
26
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27
+ return object_triples
28
+
29
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30
+ conditional_list = conditional.tolist()
31
+ crop_coordinates = None
32
+ if self.encode_crop:
33
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34
+ conditional_list = conditional_list[:-2]
35
+ object_triples = grouper(conditional_list, 3)
36
+ assert conditional.shape[0] == self.embedding_dim
37
+ return [
38
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39
+ for object_triple in object_triples if object_triple[0] != self.none
40
+ ], crop_coordinates
41
+
42
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44
+ plot = pil_image.new('RGB', figure_size, WHITE)
45
+ draw = pil_img_draw.Draw(plot)
46
+ font = ImageFont.truetype(
47
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48
+ size=get_plot_font_size(font_size, figure_size)
49
+ )
50
+ width, height = plot.size
51
+ description, crop_coordinates = self.inverse_build(conditional)
52
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53
+ annotation = self.representation_to_annotation(representation)
54
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55
+ bbox = absolute_bbox(bbox, width, height)
56
+ draw.rectangle(bbox, outline=color, width=line_width)
57
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58
+ if crop_coordinates is not None:
59
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
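Each object is encoded as a triple (class token, top-left corner token, bottom-right corner token). A hedged round-trip sketch with illustrative constructor values:

# Illustrative values: 91 classes, up to 30 objects, 1024 positional tokens.
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder

builder = ObjectsBoundingBoxConditionalBuilder(
    no_object_classes=91, no_max_objects=30, no_tokens=1024,
    encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

bbox = (0.25, 0.25, 0.5, 0.5)                # relative (x0, y0, w, h)
t0, t1 = builder.token_pair_from_bbox(bbox)  # corners on a 32x32 grid -> 264, 759
print(t0, t1)
print(builder.bbox_from_token_pair(t0, t1))  # ~ (0.258, 0.258, 0.484, 0.484): quantized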
taming/data/conditional_builder/objects_center_points.py ADDED
@@ -0,0 +1,168 @@
1
+ import math
2
+ import random
3
+ import warnings
4
+ from itertools import cycle
5
+ from typing import List, Optional, Tuple, Callable
6
+
7
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
8
+ from more_itertools.recipes import grouper
9
+ from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
10
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
11
+ absolute_bbox, rescale_annotations
12
+ from taming.data.helper_types import BoundingBox, Annotation
13
+ from taming.data.image_transforms import convert_pil_to_tensor
14
+ from torch import LongTensor, Tensor
15
+
16
+
17
+ class ObjectsCenterPointsConditionalBuilder:
18
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
19
+ use_group_parameter: bool, use_additional_parameters: bool):
20
+ self.no_object_classes = no_object_classes
21
+ self.no_max_objects = no_max_objects
22
+ self.no_tokens = no_tokens
23
+ self.encode_crop = encode_crop
24
+ self.no_sections = int(math.sqrt(self.no_tokens))
25
+ self.use_group_parameter = use_group_parameter
26
+ self.use_additional_parameters = use_additional_parameters
27
+
28
+ @property
29
+ def none(self) -> int:
30
+ return self.no_tokens - 1
31
+
32
+ @property
33
+ def object_descriptor_length(self) -> int:
34
+ return 2
35
+
36
+ @property
37
+ def embedding_dim(self) -> int:
38
+ extra_length = 2 if self.encode_crop else 0
39
+ return self.no_max_objects * self.object_descriptor_length + extra_length
40
+
41
+ def tokenize_coordinates(self, x: float, y: float) -> int:
42
+ """
43
+ Express 2d coordinates with one number.
44
+ Example: assume self.no_tokens = 16, then no_sections = 4:
45
+ 0 0 0 0
46
+ 0 0 # 0
47
+ 0 0 0 0
48
+ 0 0 0 x
49
+ Then the # position corresponds to token 6, the x position to token 15.
50
+ @param x: float in [0, 1]
51
+ @param y: float in [0, 1]
52
+ @return: discrete tokenized coordinate
53
+ """
54
+ x_discrete = int(round(x * (self.no_sections - 1)))
55
+ y_discrete = int(round(y * (self.no_sections - 1)))
56
+ return y_discrete * self.no_sections + x_discrete
57
+
58
+ def coordinates_from_token(self, token: int) -> (float, float):
59
+ x = token % self.no_sections
60
+ y = token // self.no_sections
61
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
62
+
63
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
64
+ x0, y0 = self.coordinates_from_token(token1)
65
+ x1, y1 = self.coordinates_from_token(token2)
66
+ return x0, y0, x1 - x0, y1 - y0
67
+
68
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
69
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
70
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
71
+
72
+ def inverse_build(self, conditional: LongTensor) \
73
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
74
+ conditional_list = conditional.tolist()
75
+ crop_coordinates = None
76
+ if self.encode_crop:
77
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
78
+ conditional_list = conditional_list[:-2]
79
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
80
+ assert conditional.shape[0] == self.embedding_dim
81
+ return [
82
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
83
+ for object_tuple in table_of_content if object_tuple[0] != self.none
84
+ ], crop_coordinates
85
+
86
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
87
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
88
+ plot = pil_image.new('RGB', figure_size, WHITE)
89
+ draw = pil_img_draw.Draw(plot)
90
+ circle_size = get_circle_size(figure_size)
91
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
92
+ size=get_plot_font_size(font_size, figure_size))
93
+ width, height = plot.size
94
+ description, crop_coordinates = self.inverse_build(conditional)
95
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
96
+ x_abs, y_abs = x * width, y * height
97
+ ann = self.representation_to_annotation(representation)
98
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
99
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
100
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
101
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
102
+ if crop_coordinates is not None:
103
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
104
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
105
+
106
+ def object_representation(self, annotation: Annotation) -> int:
107
+ modifier = 0
108
+ if self.use_group_parameter:
109
+ modifier |= 1 * (annotation.is_group_of is True)
110
+ if self.use_additional_parameters:
111
+ modifier |= 2 * (annotation.is_occluded is True)
112
+ modifier |= 4 * (annotation.is_depiction is True)
113
+ modifier |= 8 * (annotation.is_inside is True)
114
+ return annotation.category_no + self.no_object_classes * modifier
115
+
116
+ def representation_to_annotation(self, representation: int) -> Annotation:
117
+ category_no = representation % self.no_object_classes
118
+ modifier = representation // self.no_object_classes
119
+ # noinspection PyTypeChecker
120
+ return Annotation(
121
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
122
+ category_no=category_no,
123
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
124
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
125
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
126
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
127
+ )
128
+
129
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
130
+ return list(self.token_pair_from_bbox(crop_coordinates))
131
+
132
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
133
+ object_tuples = [
134
+ (self.object_representation(a),
135
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
136
+ for a in annotations
137
+ ]
138
+ empty_tuple = (self.none, self.none)
139
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
140
+ return object_tuples
141
+
142
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
143
+ -> LongTensor:
144
+ if len(annotations) == 0:
145
+ warnings.warn('Did not receive any annotations.')
146
+ if len(annotations) > self.no_max_objects:
147
+ warnings.warn('Received more annotations than allowed.')
148
+ annotations = annotations[:self.no_max_objects]
149
+
150
+ if not crop_coordinates:
151
+ crop_coordinates = FULL_CROP
152
+
153
+ random.shuffle(annotations)
154
+ annotations = filter_annotations(annotations, crop_coordinates)
155
+ if self.encode_crop:
156
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
157
+ if horizontal_flip:
158
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
159
+ extra = self._crop_encoder(crop_coordinates)
160
+ else:
161
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
162
+ extra = []
163
+
164
+ object_tuples = self._make_object_descriptors(annotations)
165
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
166
+ assert len(flattened) == self.embedding_dim
167
+ assert all(0 <= value < self.no_tokens for value in flattened)
168
+ return LongTensor(flattened)
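A quick check of tokenize_coordinates against the grid example in its docstring (no_tokens=16, hence a 4x4 grid); the constructor values are illustrative:

# Reproduces the docstring example: a 4x4 grid of positional tokens.
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder

builder = ObjectsCenterPointsConditionalBuilder(
    no_object_classes=10, no_max_objects=3, no_tokens=16,
    encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

print(builder.tokenize_coordinates(2 / 3, 1 / 3))  # 6, the '#' cell in the docstring
print(builder.tokenize_coordinates(1.0, 1.0))      # 15, the 'x' cell at the bottom right
print(builder.coordinates_from_token(6))           # ~ (0.667, 0.333)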
taming/data/conditional_builder/utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import importlib
2
+ from typing import List, Any, Tuple, Optional
3
+
4
+ from taming.data.helper_types import BoundingBox, Annotation
5
+
6
+ # source: seaborn, color palette tab10
7
+ COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9
+ BLACK = (0, 0, 0)
10
+ GRAY_75 = (63, 63, 63)
11
+ GRAY_50 = (127, 127, 127)
12
+ GRAY_25 = (191, 191, 191)
13
+ WHITE = (255, 255, 255)
14
+ FULL_CROP = (0., 0., 1., 1.)
15
+
16
+
17
+ def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18
+ """
19
+ Give intersection area of two rectangles.
20
+ @param rectangle1: (x0, y0, w, h) of first rectangle
21
+ @param rectangle2: (x0, y0, w, h) of second rectangle
22
+ """
23
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27
+ return x_overlap * y_overlap
28
+
29
+
30
+ def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32
+
33
+
34
+ def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35
+ bbox = relative_bbox
36
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38
+
39
+
40
+ def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42
+
43
+
44
+ def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45
+ List[Annotation]:
46
+ def clamp(x: float):
47
+ return max(min(x, 1.), 0.)
48
+
49
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54
+ if flip:
55
+ x0 = 1 - (x0 + w)
56
+ return x0, y0, w, h
57
+
58
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59
+
60
+
61
+ def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63
+
64
+
65
+ def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66
+ sl = slice(1) if short else slice(None)
67
+ string = ''
68
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69
+ return string
70
+ if annotation.is_group_of:
71
+ string += 'group'[sl] + ','
72
+ if annotation.is_occluded:
73
+ string += 'occluded'[sl] + ','
74
+ if annotation.is_depiction:
75
+ string += 'depiction'[sl] + ','
76
+ if annotation.is_inside:
77
+ string += 'inside'[sl]
78
+ return '(' + string.strip(",") + ')'
79
+
80
+
81
+ def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82
+ if font_size is None:
83
+ font_size = 10
84
+ if max(figure_size) >= 256:
85
+ font_size = 12
86
+ if max(figure_size) >= 512:
87
+ font_size = 15
88
+ return font_size
89
+
90
+
91
+ def get_circle_size(figure_size: Tuple[int, int]) -> int:
92
+ circle_size = 2
93
+ if max(figure_size) >= 256:
94
+ circle_size = 3
95
+ if max(figure_size) >= 512:
96
+ circle_size = 4
97
+ return circle_size
98
+
99
+
100
+ def load_object_from_string(object_string: str) -> Any:
101
+ """
102
+ Source: https://stackoverflow.com/a/10773699
103
+ """
104
+ module_name, class_name = object_string.rsplit(".", 1)
105
+ return getattr(importlib.import_module(module_name), class_name)
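A few worked values for the geometry helpers above, all in relative (x0, y0, w, h) coordinates:

from taming.data.conditional_builder.utils import (
    intersection_area, horizontally_flip_bbox, pad_list)

print(intersection_area((0.0, 0.0, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5)))  # 0.0625
print(horizontally_flip_bbox((0.1, 0.2, 0.3, 0.4)))                     # ~ (0.6, 0.2, 0.3, 0.4)
print(pad_list([7, 8], 0, 4))                                           # [7, 8, 0, 0]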
taming/data/custom.py ADDED
@@ -0,0 +1,38 @@
1
+ import os
2
+ import numpy as np
3
+ import albumentations
4
+ from torch.utils.data import Dataset
5
+
6
+ from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7
+
8
+
9
+ class CustomBase(Dataset):
10
+ def __init__(self, *args, **kwargs):
11
+ super().__init__()
12
+ self.data = None
13
+
14
+ def __len__(self):
15
+ return len(self.data)
16
+
17
+ def __getitem__(self, i):
18
+ example = self.data[i]
19
+ return example
20
+
21
+
22
+
23
+ class CustomTrain(CustomBase):
24
+ def __init__(self, size, training_images_list_file):
25
+ super().__init__()
26
+ with open(training_images_list_file, "r") as f:
27
+ paths = f.read().splitlines()
28
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29
+
30
+
31
+ class CustomTest(CustomBase):
32
+ def __init__(self, size, test_images_list_file):
33
+ super().__init__()
34
+ with open(test_images_list_file, "r") as f:
35
+ paths = f.read().splitlines()
36
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37
+
38
+
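A hedged usage sketch for the custom loaders; train_images.txt is a placeholder file listing one image path per line:

# Placeholder list file; each line is a path to an RGB image.
from taming.data.custom import CustomTrain

dset = CustomTrain(size=256, training_images_list_file="train_images.txt")
ex = dset[0]
print(ex["image"].shape)  # (256, 256, 3), float32 in [-1, 1]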
taming/data/faceshq.py ADDED
@@ -0,0 +1,134 @@
+ import os
+ import numpy as np
+ import albumentations
+ from torch.utils.data import Dataset
+
+ from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+ class FacesBase(Dataset):
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self.data = None
+         self.keys = None
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, i):
+         example = self.data[i]
+         ex = {}
+         if self.keys is not None:
+             for k in self.keys:
+                 ex[k] = example[k]
+         else:
+             ex = example
+         return ex
+
+
+ class CelebAHQTrain(FacesBase):
+     def __init__(self, size, keys=None):
+         super().__init__()
+         root = "data/celebahq"
+         with open("data/celebahqtrain.txt", "r") as f:
+             relpaths = f.read().splitlines()
+         paths = [os.path.join(root, relpath) for relpath in relpaths]
+         self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+         self.keys = keys
+
+
+ class CelebAHQValidation(FacesBase):
+     def __init__(self, size, keys=None):
+         super().__init__()
+         root = "data/celebahq"
+         with open("data/celebahqvalidation.txt", "r") as f:
+             relpaths = f.read().splitlines()
+         paths = [os.path.join(root, relpath) for relpath in relpaths]
+         self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+         self.keys = keys
+
+
+ class FFHQTrain(FacesBase):
+     def __init__(self, size, keys=None):
+         super().__init__()
+         root = "data/ffhq"
+         with open("data/ffhqtrain.txt", "r") as f:
+             relpaths = f.read().splitlines()
+         paths = [os.path.join(root, relpath) for relpath in relpaths]
+         self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+         self.keys = keys
+
+
+ class FFHQValidation(FacesBase):
+     def __init__(self, size, keys=None):
+         super().__init__()
+         root = "data/ffhq"
+         with open("data/ffhqvalidation.txt", "r") as f:
+             relpaths = f.read().splitlines()
+         paths = [os.path.join(root, relpath) for relpath in relpaths]
+         self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+         self.keys = keys
+
+
+ class FacesHQTrain(Dataset):
+     # CelebAHQ [0] + FFHQ [1]
+     def __init__(self, size, keys=None, crop_size=None, coord=False):
+         d1 = CelebAHQTrain(size=size, keys=keys)
+         d2 = FFHQTrain(size=size, keys=keys)
+         self.data = ConcatDatasetWithIndex([d1, d2])
+         self.coord = coord
+         if crop_size is not None:
+             self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+             if self.coord:
+                 self.cropper = albumentations.Compose([self.cropper],
+                                                       additional_targets={"coord": "image"})
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, i):
+         ex, y = self.data[i]
+         if hasattr(self, "cropper"):
+             if not self.coord:
+                 out = self.cropper(image=ex["image"])
+                 ex["image"] = out["image"]
+             else:
+                 h,w,_ = ex["image"].shape
+                 coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+                 out = self.cropper(image=ex["image"], coord=coord)
+                 ex["image"] = out["image"]
+                 ex["coord"] = out["coord"]
+         ex["class"] = y
+         return ex
+
+
+ class FacesHQValidation(Dataset):
+     # CelebAHQ [0] + FFHQ [1]
+     def __init__(self, size, keys=None, crop_size=None, coord=False):
+         d1 = CelebAHQValidation(size=size, keys=keys)
+         d2 = FFHQValidation(size=size, keys=keys)
+         self.data = ConcatDatasetWithIndex([d1, d2])
+         self.coord = coord
+         if crop_size is not None:
+             self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+             if self.coord:
+                 self.cropper = albumentations.Compose([self.cropper],
+                                                       additional_targets={"coord": "image"})
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, i):
+         ex, y = self.data[i]
+         if hasattr(self, "cropper"):
+             if not self.coord:
+                 out = self.cropper(image=ex["image"])
+                 ex["image"] = out["image"]
+             else:
+                 h,w,_ = ex["image"].shape
+                 coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+                 out = self.cropper(image=ex["image"], coord=coord)
+                 ex["image"] = out["image"]
+                 ex["coord"] = out["coord"]
+         ex["class"] = y
+         return ex
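For orientation only (not part of the uploaded file): a hedged usage sketch of the combined dataset, assuming the data/celebahq* and data/ffhq* file lists referenced above are in place.

    from taming.data.faceshq import FacesHQTrain

    # Concatenates CelebA-HQ (class 0) and FFHQ (class 1), optionally with a coordinate map.
    ds = FacesHQTrain(size=256, crop_size=256, coord=True)
    ex = ds[0]
    # ex["image"]: randomly cropped image, ex["coord"]: matching coordinate channel,
    # ex["class"]: 0 for CelebA-HQ samples, 1 for FFHQ samples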
taming/data/helper_types.py ADDED
@@ -0,0 +1,49 @@
+ from typing import Dict, Tuple, Optional, NamedTuple, Union
+ from PIL.Image import Image as pil_image
+ from torch import Tensor
+
+ try:
+     from typing import Literal
+ except ImportError:
+     from typing_extensions import Literal
+
+ Image = Union[Tensor, pil_image]
+ BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+ CropMethodType = Literal['none', 'random', 'center', 'random-2d']
+ SplitType = Literal['train', 'validation', 'test']
+
+
+ class ImageDescription(NamedTuple):
+     id: int
+     file_name: str
+     original_size: Tuple[int, int] # w, h
+     url: Optional[str] = None
+     license: Optional[int] = None
+     coco_url: Optional[str] = None
+     date_captured: Optional[str] = None
+     flickr_url: Optional[str] = None
+     flickr_id: Optional[str] = None
+     coco_id: Optional[str] = None
+
+
+ class Category(NamedTuple):
+     id: str
+     super_category: Optional[str]
+     name: str
+
+
+ class Annotation(NamedTuple):
+     area: float
+     image_id: str
+     bbox: BoundingBox
+     category_no: int
+     category_id: str
+     id: Optional[int] = None
+     source: Optional[str] = None
+     confidence: Optional[float] = None
+     is_group_of: Optional[bool] = None
+     is_truncated: Optional[bool] = None
+     is_occluded: Optional[bool] = None
+     is_depiction: Optional[bool] = None
+     is_inside: Optional[bool] = None
+     segmentation: Optional[Dict] = None
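For orientation only (not part of the uploaded file): the NamedTuples above are plain value containers; a small illustrative sketch with made-up values.

    from taming.data.helper_types import Annotation

    # Only the five required fields are given; the remaining flags default to None.
    ann = Annotation(area=1024.0, image_id="img_0001", bbox=(0.1, 0.2, 0.3, 0.4),
                     category_no=0, category_id="/m/01g317")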
taming/data/image_transforms.py ADDED
@@ -0,0 +1,132 @@
+ import random
+ import warnings
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+ from torchvision.transforms.functional import _get_image_size as get_image_size
+
+ from taming.data.helper_types import BoundingBox, Image
+
+ pil_to_tensor = PILToTensor()
+
+
+ def convert_pil_to_tensor(image: Image) -> Tensor:
+     with warnings.catch_warnings():
+         # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+         warnings.simplefilter("ignore")
+         return pil_to_tensor(image)
+
+
+ class RandomCrop1dReturnCoordinates(RandomCrop):
+     def forward(self, img: Image) -> (BoundingBox, Image):
+         """
+         Additionally to cropping, returns the relative coordinates of the crop bounding box.
+         Args:
+             img (PIL Image or Tensor): Image to be cropped.
+
+         Returns:
+             Bounding box: x0, y0, w, h
+             PIL Image or Tensor: Cropped image.
+
+         Based on:
+             torchvision.transforms.RandomCrop, torchvision 1.7.0
+         """
+         if self.padding is not None:
+             img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+         width, height = get_image_size(img)
+         # pad the width if needed
+         if self.pad_if_needed and width < self.size[1]:
+             padding = [self.size[1] - width, 0]
+             img = F.pad(img, padding, self.fill, self.padding_mode)
+         # pad the height if needed
+         if self.pad_if_needed and height < self.size[0]:
+             padding = [0, self.size[0] - height]
+             img = F.pad(img, padding, self.fill, self.padding_mode)
+
+         i, j, h, w = self.get_params(img, self.size)
+         bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+         return bbox, F.crop(img, i, j, h, w)
+
+
+ class Random2dCropReturnCoordinates(torch.nn.Module):
+     """
+     Additionally to cropping, returns the relative coordinates of the crop bounding box.
+     Args:
+         img (PIL Image or Tensor): Image to be cropped.
+
+     Returns:
+         Bounding box: x0, y0, w, h
+         PIL Image or Tensor: Cropped image.
+
+     Based on:
+         torchvision.transforms.RandomCrop, torchvision 1.7.0
+     """
+
+     def __init__(self, min_size: int):
+         super().__init__()
+         self.min_size = min_size
+
+     def forward(self, img: Image) -> (BoundingBox, Image):
+         width, height = get_image_size(img)
+         max_size = min(width, height)
+         if max_size <= self.min_size:
+             size = max_size
+         else:
+             size = random.randint(self.min_size, max_size)
+         top = random.randint(0, height - size)
+         left = random.randint(0, width - size)
+         bbox = left / width, top / height, size / width, size / height
+         return bbox, F.crop(img, top, left, size, size)
+
+
+ class CenterCropReturnCoordinates(CenterCrop):
+     @staticmethod
+     def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
+         if width > height:
+             w = height / width
+             h = 1.0
+             x0 = 0.5 - w / 2
+             y0 = 0.
+         else:
+             w = 1.0
+             h = width / height
+             x0 = 0.
+             y0 = 0.5 - h / 2
+         return x0, y0, w, h
+
+     def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+         """
+         Additionally to cropping, returns the relative coordinates of the crop bounding box.
+         Args:
+             img (PIL Image or Tensor): Image to be cropped.
+
+         Returns:
+             Bounding box: x0, y0, w, h
+             PIL Image or Tensor: Cropped image.
+         Based on:
+             torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+         """
+         width, height = get_image_size(img)
+         return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+ class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+     def forward(self, img: Image) -> (bool, Image):
+         """
+         Additionally to flipping, returns a boolean whether it was flipped or not.
+         Args:
+             img (PIL Image or Tensor): Image to be flipped.
+
+         Returns:
+             flipped: whether the image was flipped or not
+             PIL Image or Tensor: Randomly flipped image.
+
+         Based on:
+             torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+         """
+         if torch.rand(1) < self.p:
+             return True, F.hflip(img)
+         return False, img
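For orientation only (not part of the uploaded file): a hedged sketch of calling the coordinate-returning crops, assuming a torchvision release around 0.8/1.7 that still exposes the private _get_image_size helper imported above.

    from PIL import Image
    from taming.data.image_transforms import CenterCropReturnCoordinates

    crop = CenterCropReturnCoordinates(size=256)
    img = Image.new("RGB", (512, 384))  # placeholder image
    bbox, cropped = crop(img)           # bbox = (x0, y0, w, h) relative to the input image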
taming/data/imagenet.py ADDED
@@ -0,0 +1,558 @@
1
+ import os, tarfile, glob, shutil
2
+ import yaml
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from PIL import Image
6
+ import albumentations
7
+ from omegaconf import OmegaConf
8
+ from torch.utils.data import Dataset
9
+
10
+ from taming.data.base import ImagePaths
11
+ from taming.util import download, retrieve
12
+ import taming.data.utils as bdu
13
+
14
+
15
+ def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
16
+ synsets = []
17
+ with open(path_to_yaml) as f:
18
+ di2s = yaml.load(f)
19
+ for idx in indices:
20
+ synsets.append(str(di2s[idx]))
21
+ print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
22
+ return synsets
23
+
24
+
25
+ def str_to_indices(string):
26
+ """Expects a string in the format '32-123, 256, 280-321'"""
27
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
28
+ subs = string.split(",")
29
+ indices = []
30
+ for sub in subs:
31
+ subsubs = sub.split("-")
32
+ assert len(subsubs) > 0
33
+ if len(subsubs) == 1:
34
+ indices.append(int(subsubs[0]))
35
+ else:
36
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
37
+ indices.extend(rang)
38
+ return sorted(indices)
39
+
40
+
41
+ class ImageNetBase(Dataset):
42
+ def __init__(self, config=None):
43
+ self.config = config or OmegaConf.create()
44
+ if not type(self.config)==dict:
45
+ self.config = OmegaConf.to_container(self.config)
46
+ self._prepare()
47
+ self._prepare_synset_to_human()
48
+ self._prepare_idx_to_synset()
49
+ self._load()
50
+
51
+ def __len__(self):
52
+ return len(self.data)
53
+
54
+ def __getitem__(self, i):
55
+ return self.data[i]
56
+
57
+ def _prepare(self):
58
+ raise NotImplementedError()
59
+
60
+ def _filter_relpaths(self, relpaths):
61
+ ignore = set([
62
+ "n06596364_9591.JPEG",
63
+ ])
64
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
65
+ if "sub_indices" in self.config:
66
+ indices = str_to_indices(self.config["sub_indices"])
67
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
68
+ files = []
69
+ for rpath in relpaths:
70
+ syn = rpath.split("/")[0]
71
+ if syn in synsets:
72
+ files.append(rpath)
73
+ return files
74
+ else:
75
+ return relpaths
76
+
77
+ def _prepare_synset_to_human(self):
78
+ SIZE = 2655750
79
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
80
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
81
+ if (not os.path.exists(self.human_dict) or
82
+ not os.path.getsize(self.human_dict)==SIZE):
83
+ download(URL, self.human_dict)
84
+
85
+ def _prepare_idx_to_synset(self):
86
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
87
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
88
+ if (not os.path.exists(self.idx2syn)):
89
+ download(URL, self.idx2syn)
90
+
91
+ def _load(self):
92
+ with open(self.txt_filelist, "r") as f:
93
+ self.relpaths = f.read().splitlines()
94
+ l1 = len(self.relpaths)
95
+ self.relpaths = self._filter_relpaths(self.relpaths)
96
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
97
+
98
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
99
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
100
+
101
+ unique_synsets = np.unique(self.synsets)
102
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
103
+ self.class_labels = [class_dict[s] for s in self.synsets]
104
+
105
+ with open(self.human_dict, "r") as f:
106
+ human_dict = f.read().splitlines()
107
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
108
+
109
+ self.human_labels = [human_dict[s] for s in self.synsets]
110
+
111
+ labels = {
112
+ "relpath": np.array(self.relpaths),
113
+ "synsets": np.array(self.synsets),
114
+ "class_label": np.array(self.class_labels),
115
+ "human_label": np.array(self.human_labels),
116
+ }
117
+ self.data = ImagePaths(self.abspaths,
118
+ labels=labels,
119
+ size=retrieve(self.config, "size", default=0),
120
+ random_crop=self.random_crop)
121
+
122
+
123
+ class ImageNetTrain(ImageNetBase):
124
+ NAME = "ILSVRC2012_train"
125
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
126
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
127
+ FILES = [
128
+ "ILSVRC2012_img_train.tar",
129
+ ]
130
+ SIZES = [
131
+ 147897477120,
132
+ ]
133
+
134
+ def _prepare(self):
135
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
136
+ default=True)
137
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
138
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
139
+ self.datadir = os.path.join(self.root, "data")
140
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
141
+ self.expected_length = 1281167
142
+ if not bdu.is_prepared(self.root):
143
+ # prep
144
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
145
+
146
+ datadir = self.datadir
147
+ if not os.path.exists(datadir):
148
+ path = os.path.join(self.root, self.FILES[0])
149
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
150
+ import academictorrents as at
151
+ atpath = at.get(self.AT_HASH, datastore=self.root)
152
+ assert atpath == path
153
+
154
+ print("Extracting {} to {}".format(path, datadir))
155
+ os.makedirs(datadir, exist_ok=True)
156
+ with tarfile.open(path, "r:") as tar:
157
+ tar.extractall(path=datadir)
158
+
159
+ print("Extracting sub-tars.")
160
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
161
+ for subpath in tqdm(subpaths):
162
+ subdir = subpath[:-len(".tar")]
163
+ os.makedirs(subdir, exist_ok=True)
164
+ with tarfile.open(subpath, "r:") as tar:
165
+ tar.extractall(path=subdir)
166
+
167
+
168
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
169
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
170
+ filelist = sorted(filelist)
171
+ filelist = "\n".join(filelist)+"\n"
172
+ with open(self.txt_filelist, "w") as f:
173
+ f.write(filelist)
174
+
175
+ bdu.mark_prepared(self.root)
176
+
177
+
178
+ class ImageNetValidation(ImageNetBase):
179
+ NAME = "ILSVRC2012_validation"
180
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
181
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
182
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
183
+ FILES = [
184
+ "ILSVRC2012_img_val.tar",
185
+ "validation_synset.txt",
186
+ ]
187
+ SIZES = [
188
+ 6744924160,
189
+ 1950000,
190
+ ]
191
+
192
+ def _prepare(self):
193
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
194
+ default=False)
195
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
196
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
197
+ self.datadir = os.path.join(self.root, "data")
198
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
199
+ self.expected_length = 50000
200
+ if not bdu.is_prepared(self.root):
201
+ # prep
202
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
203
+
204
+ datadir = self.datadir
205
+ if not os.path.exists(datadir):
206
+ path = os.path.join(self.root, self.FILES[0])
207
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
208
+ import academictorrents as at
209
+ atpath = at.get(self.AT_HASH, datastore=self.root)
210
+ assert atpath == path
211
+
212
+ print("Extracting {} to {}".format(path, datadir))
213
+ os.makedirs(datadir, exist_ok=True)
214
+ with tarfile.open(path, "r:") as tar:
215
+ tar.extractall(path=datadir)
216
+
217
+ vspath = os.path.join(self.root, self.FILES[1])
218
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
219
+ download(self.VS_URL, vspath)
220
+
221
+ with open(vspath, "r") as f:
222
+ synset_dict = f.read().splitlines()
223
+ synset_dict = dict(line.split() for line in synset_dict)
224
+
225
+ print("Reorganizing into synset folders")
226
+ synsets = np.unique(list(synset_dict.values()))
227
+ for s in synsets:
228
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
229
+ for k, v in synset_dict.items():
230
+ src = os.path.join(datadir, k)
231
+ dst = os.path.join(datadir, v)
232
+ shutil.move(src, dst)
233
+
234
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
235
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
236
+ filelist = sorted(filelist)
237
+ filelist = "\n".join(filelist)+"\n"
238
+ with open(self.txt_filelist, "w") as f:
239
+ f.write(filelist)
240
+
241
+ bdu.mark_prepared(self.root)
242
+
243
+
244
+ def get_preprocessor(size=None, random_crop=False, additional_targets=None,
245
+ crop_size=None):
246
+ if size is not None and size > 0:
247
+ transforms = list()
248
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
249
+ transforms.append(rescaler)
250
+ if not random_crop:
251
+ cropper = albumentations.CenterCrop(height=size,width=size)
252
+ transforms.append(cropper)
253
+ else:
254
+ cropper = albumentations.RandomCrop(height=size,width=size)
255
+ transforms.append(cropper)
256
+ flipper = albumentations.HorizontalFlip()
257
+ transforms.append(flipper)
258
+ preprocessor = albumentations.Compose(transforms,
259
+ additional_targets=additional_targets)
260
+ elif crop_size is not None and crop_size > 0:
261
+ if not random_crop:
262
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
263
+ else:
264
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
265
+ transforms = [cropper]
266
+ preprocessor = albumentations.Compose(transforms,
267
+ additional_targets=additional_targets)
268
+ else:
269
+ preprocessor = lambda **kwargs: kwargs
270
+ return preprocessor
271
+
272
+
273
+ def rgba_to_depth(x):
274
+ assert x.dtype == np.uint8
275
+ assert len(x.shape) == 3 and x.shape[2] == 4
276
+ y = x.copy()
277
+ y.dtype = np.float32
278
+ y = y.reshape(x.shape[:2])
279
+ return np.ascontiguousarray(y)
280
+
281
+
282
+ class BaseWithDepth(Dataset):
283
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
284
+
285
+ def __init__(self, config=None, size=None, random_crop=False,
286
+ crop_size=None, root=None):
287
+ self.config = config
288
+ self.base_dset = self.get_base_dset()
289
+ self.preprocessor = get_preprocessor(
290
+ size=size,
291
+ crop_size=crop_size,
292
+ random_crop=random_crop,
293
+ additional_targets={"depth": "image"})
294
+ self.crop_size = crop_size
295
+ if self.crop_size is not None:
296
+ self.rescaler = albumentations.Compose(
297
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
298
+ additional_targets={"depth": "image"})
299
+ if root is not None:
300
+ self.DEFAULT_DEPTH_ROOT = root
301
+
302
+ def __len__(self):
303
+ return len(self.base_dset)
304
+
305
+ def preprocess_depth(self, path):
306
+ rgba = np.array(Image.open(path))
307
+ depth = rgba_to_depth(rgba)
308
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
309
+ depth = 2.0*depth-1.0
310
+ return depth
311
+
312
+ def __getitem__(self, i):
313
+ e = self.base_dset[i]
314
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
315
+ # up if necessary
316
+ h,w,c = e["image"].shape
317
+ if self.crop_size and min(h,w) < self.crop_size:
318
+ # have to upscale to be able to crop - this just uses bilinear
319
+ out = self.rescaler(image=e["image"], depth=e["depth"])
320
+ e["image"] = out["image"]
321
+ e["depth"] = out["depth"]
322
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
323
+ e["image"] = transformed["image"]
324
+ e["depth"] = transformed["depth"]
325
+ return e
326
+
327
+
328
+ class ImageNetTrainWithDepth(BaseWithDepth):
329
+ # default to random_crop=True
330
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
331
+ self.sub_indices = sub_indices
332
+ super().__init__(random_crop=random_crop, **kwargs)
333
+
334
+ def get_base_dset(self):
335
+ if self.sub_indices is None:
336
+ return ImageNetTrain()
337
+ else:
338
+ return ImageNetTrain({"sub_indices": self.sub_indices})
339
+
340
+ def get_depth_path(self, e):
341
+ fid = os.path.splitext(e["relpath"])[0]+".png"
342
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
343
+ return fid
344
+
345
+
346
+ class ImageNetValidationWithDepth(BaseWithDepth):
347
+ def __init__(self, sub_indices=None, **kwargs):
348
+ self.sub_indices = sub_indices
349
+ super().__init__(**kwargs)
350
+
351
+ def get_base_dset(self):
352
+ if self.sub_indices is None:
353
+ return ImageNetValidation()
354
+ else:
355
+ return ImageNetValidation({"sub_indices": self.sub_indices})
356
+
357
+ def get_depth_path(self, e):
358
+ fid = os.path.splitext(e["relpath"])[0]+".png"
359
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
360
+ return fid
361
+
362
+
363
+ class RINTrainWithDepth(ImageNetTrainWithDepth):
364
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
365
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
366
+ super().__init__(config=config, size=size, random_crop=random_crop,
367
+ sub_indices=sub_indices, crop_size=crop_size)
368
+
369
+
370
+ class RINValidationWithDepth(ImageNetValidationWithDepth):
371
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
372
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
373
+ super().__init__(config=config, size=size, random_crop=random_crop,
374
+ sub_indices=sub_indices, crop_size=crop_size)
375
+
376
+
377
+ class DRINExamples(Dataset):
378
+ def __init__(self):
379
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
380
+ with open("data/drin_examples.txt", "r") as f:
381
+ relpaths = f.read().splitlines()
382
+ self.image_paths = [os.path.join("data/drin_images",
383
+ relpath) for relpath in relpaths]
384
+ self.depth_paths = [os.path.join("data/drin_depth",
385
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
386
+
387
+ def __len__(self):
388
+ return len(self.image_paths)
389
+
390
+ def preprocess_image(self, image_path):
391
+ image = Image.open(image_path)
392
+ if not image.mode == "RGB":
393
+ image = image.convert("RGB")
394
+ image = np.array(image).astype(np.uint8)
395
+ image = self.preprocessor(image=image)["image"]
396
+ image = (image/127.5 - 1.0).astype(np.float32)
397
+ return image
398
+
399
+ def preprocess_depth(self, path):
400
+ rgba = np.array(Image.open(path))
401
+ depth = rgba_to_depth(rgba)
402
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
403
+ depth = 2.0*depth-1.0
404
+ return depth
405
+
406
+ def __getitem__(self, i):
407
+ e = dict()
408
+ e["image"] = self.preprocess_image(self.image_paths[i])
409
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
410
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
411
+ e["image"] = transformed["image"]
412
+ e["depth"] = transformed["depth"]
413
+ return e
414
+
415
+
416
+ def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
417
+ if factor is None or factor==1:
418
+ return x
419
+
420
+ dtype = x.dtype
421
+ assert dtype in [np.float32, np.float64]
422
+ assert x.min() >= -1
423
+ assert x.max() <= 1
424
+
425
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
426
+ "bicubic": Image.BICUBIC}[keepmode]
427
+
428
+ lr = (x+1.0)*127.5
429
+ lr = lr.clip(0,255).astype(np.uint8)
430
+ lr = Image.fromarray(lr)
431
+
432
+ h, w, _ = x.shape
433
+ nh = h//factor
434
+ nw = w//factor
435
+ assert nh > 0 and nw > 0, (nh, nw)
436
+
437
+ lr = lr.resize((nw,nh), Image.BICUBIC)
438
+ if keepshapes:
439
+ lr = lr.resize((w,h), keepmode)
440
+ lr = np.array(lr)/127.5-1.0
441
+ lr = lr.astype(dtype)
442
+
443
+ return lr
444
+
445
+
446
+ class ImageNetScale(Dataset):
447
+ def __init__(self, size=None, crop_size=None, random_crop=False,
448
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
449
+ self.base = self.get_base()
450
+
451
+ self.size = size
452
+ self.crop_size = crop_size if crop_size is not None else self.size
453
+ self.random_crop = random_crop
454
+ self.up_factor = up_factor
455
+ self.hr_factor = hr_factor
456
+ self.keep_mode = keep_mode
457
+
458
+ transforms = list()
459
+
460
+ if self.size is not None and self.size > 0:
461
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
462
+ self.rescaler = rescaler
463
+ transforms.append(rescaler)
464
+
465
+ if self.crop_size is not None and self.crop_size > 0:
466
+ if len(transforms) == 0:
467
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
468
+
469
+ if not self.random_crop:
470
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
471
+ else:
472
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
473
+ transforms.append(cropper)
474
+
475
+ if len(transforms) > 0:
476
+ if self.up_factor is not None:
477
+ additional_targets = {"lr": "image"}
478
+ else:
479
+ additional_targets = None
480
+ self.preprocessor = albumentations.Compose(transforms,
481
+ additional_targets=additional_targets)
482
+ else:
483
+ self.preprocessor = lambda **kwargs: kwargs
484
+
485
+ def __len__(self):
486
+ return len(self.base)
487
+
488
+ def __getitem__(self, i):
489
+ example = self.base[i]
490
+ image = example["image"]
491
+ # adjust resolution
492
+ image = imscale(image, self.hr_factor, keepshapes=False)
493
+ h,w,c = image.shape
494
+ if self.crop_size and min(h,w) < self.crop_size:
495
+ # have to upscale to be able to crop - this just uses bilinear
496
+ image = self.rescaler(image=image)["image"]
497
+ if self.up_factor is None:
498
+ image = self.preprocessor(image=image)["image"]
499
+ example["image"] = image
500
+ else:
501
+ lr = imscale(image, self.up_factor, keepshapes=True,
502
+ keepmode=self.keep_mode)
503
+
504
+ out = self.preprocessor(image=image, lr=lr)
505
+ example["image"] = out["image"]
506
+ example["lr"] = out["lr"]
507
+
508
+ return example
509
+
510
+ class ImageNetScaleTrain(ImageNetScale):
511
+ def __init__(self, random_crop=True, **kwargs):
512
+ super().__init__(random_crop=random_crop, **kwargs)
513
+
514
+ def get_base(self):
515
+ return ImageNetTrain()
516
+
517
+ class ImageNetScaleValidation(ImageNetScale):
518
+ def get_base(self):
519
+ return ImageNetValidation()
520
+
521
+
522
+ from skimage.feature import canny
523
+ from skimage.color import rgb2gray
524
+
525
+
526
+ class ImageNetEdges(ImageNetScale):
527
+ def __init__(self, up_factor=1, **kwargs):
528
+ super().__init__(up_factor=1, **kwargs)
529
+
530
+ def __getitem__(self, i):
531
+ example = self.base[i]
532
+ image = example["image"]
533
+ h,w,c = image.shape
534
+ if self.crop_size and min(h,w) < self.crop_size:
535
+ # have to upscale to be able to crop - this just uses bilinear
536
+ image = self.rescaler(image=image)["image"]
537
+
538
+ lr = canny(rgb2gray(image), sigma=2)
539
+ lr = lr.astype(np.float32)
540
+ lr = lr[:,:,None][:,:,[0,0,0]]
541
+
542
+ out = self.preprocessor(image=image, lr=lr)
543
+ example["image"] = out["image"]
544
+ example["lr"] = out["lr"]
545
+
546
+ return example
547
+
548
+
549
+ class ImageNetEdgesTrain(ImageNetEdges):
550
+ def __init__(self, random_crop=True, **kwargs):
551
+ super().__init__(random_crop=random_crop, **kwargs)
552
+
553
+ def get_base(self):
554
+ return ImageNetTrain()
555
+
556
+ class ImageNetEdgesValidation(ImageNetEdges):
557
+ def get_base(self):
558
+ return ImageNetValidation()
taming/data/open_images_helper.py ADDED
@@ -0,0 +1,379 @@
1
+ open_images_unify_categories_for_coco = {
2
+ '/m/03bt1vf': '/m/01g317',
3
+ '/m/04yx4': '/m/01g317',
4
+ '/m/05r655': '/m/01g317',
5
+ '/m/01bl7v': '/m/01g317',
6
+ '/m/0cnyhnx': '/m/01xq0k1',
7
+ '/m/01226z': '/m/018xm',
8
+ '/m/05ctyq': '/m/018xm',
9
+ '/m/058qzx': '/m/04ctx',
10
+ '/m/06pcq': '/m/0l515',
11
+ '/m/03m3pdh': '/m/02crq1',
12
+ '/m/046dlr': '/m/01x3z',
13
+ '/m/0h8mzrc': '/m/01x3z',
14
+ }
15
+
16
+
17
+ top_300_classes_plus_coco_compatibility = [
18
+ ('Man', 1060962),
19
+ ('Clothing', 986610),
20
+ ('Tree', 748162),
21
+ ('Woman', 611896),
22
+ ('Person', 610294),
23
+ ('Human face', 442948),
24
+ ('Girl', 175399),
25
+ ('Building', 162147),
26
+ ('Car', 159135),
27
+ ('Plant', 155704),
28
+ ('Human body', 137073),
29
+ ('Flower', 133128),
30
+ ('Window', 127485),
31
+ ('Human arm', 118380),
32
+ ('House', 114365),
33
+ ('Wheel', 111684),
34
+ ('Suit', 99054),
35
+ ('Human hair', 98089),
36
+ ('Human head', 92763),
37
+ ('Chair', 88624),
38
+ ('Boy', 79849),
39
+ ('Table', 73699),
40
+ ('Jeans', 57200),
41
+ ('Tire', 55725),
42
+ ('Skyscraper', 53321),
43
+ ('Food', 52400),
44
+ ('Footwear', 50335),
45
+ ('Dress', 50236),
46
+ ('Human leg', 47124),
47
+ ('Toy', 46636),
48
+ ('Tower', 45605),
49
+ ('Boat', 43486),
50
+ ('Land vehicle', 40541),
51
+ ('Bicycle wheel', 34646),
52
+ ('Palm tree', 33729),
53
+ ('Fashion accessory', 32914),
54
+ ('Glasses', 31940),
55
+ ('Bicycle', 31409),
56
+ ('Furniture', 30656),
57
+ ('Sculpture', 29643),
58
+ ('Bottle', 27558),
59
+ ('Dog', 26980),
60
+ ('Snack', 26796),
61
+ ('Human hand', 26664),
62
+ ('Bird', 25791),
63
+ ('Book', 25415),
64
+ ('Guitar', 24386),
65
+ ('Jacket', 23998),
66
+ ('Poster', 22192),
67
+ ('Dessert', 21284),
68
+ ('Baked goods', 20657),
69
+ ('Drink', 19754),
70
+ ('Flag', 18588),
71
+ ('Houseplant', 18205),
72
+ ('Tableware', 17613),
73
+ ('Airplane', 17218),
74
+ ('Door', 17195),
75
+ ('Sports uniform', 17068),
76
+ ('Shelf', 16865),
77
+ ('Drum', 16612),
78
+ ('Vehicle', 16542),
79
+ ('Microphone', 15269),
80
+ ('Street light', 14957),
81
+ ('Cat', 14879),
82
+ ('Fruit', 13684),
83
+ ('Fast food', 13536),
84
+ ('Animal', 12932),
85
+ ('Vegetable', 12534),
86
+ ('Train', 12358),
87
+ ('Horse', 11948),
88
+ ('Flowerpot', 11728),
89
+ ('Motorcycle', 11621),
90
+ ('Fish', 11517),
91
+ ('Desk', 11405),
92
+ ('Helmet', 10996),
93
+ ('Truck', 10915),
94
+ ('Bus', 10695),
95
+ ('Hat', 10532),
96
+ ('Auto part', 10488),
97
+ ('Musical instrument', 10303),
98
+ ('Sunglasses', 10207),
99
+ ('Picture frame', 10096),
100
+ ('Sports equipment', 10015),
101
+ ('Shorts', 9999),
102
+ ('Wine glass', 9632),
103
+ ('Duck', 9242),
104
+ ('Wine', 9032),
105
+ ('Rose', 8781),
106
+ ('Tie', 8693),
107
+ ('Butterfly', 8436),
108
+ ('Beer', 7978),
109
+ ('Cabinetry', 7956),
110
+ ('Laptop', 7907),
111
+ ('Insect', 7497),
112
+ ('Goggles', 7363),
113
+ ('Shirt', 7098),
114
+ ('Dairy Product', 7021),
115
+ ('Marine invertebrates', 7014),
116
+ ('Cattle', 7006),
117
+ ('Trousers', 6903),
118
+ ('Van', 6843),
119
+ ('Billboard', 6777),
120
+ ('Balloon', 6367),
121
+ ('Human nose', 6103),
122
+ ('Tent', 6073),
123
+ ('Camera', 6014),
124
+ ('Doll', 6002),
125
+ ('Coat', 5951),
126
+ ('Mobile phone', 5758),
127
+ ('Swimwear', 5729),
128
+ ('Strawberry', 5691),
129
+ ('Stairs', 5643),
130
+ ('Goose', 5599),
131
+ ('Umbrella', 5536),
132
+ ('Cake', 5508),
133
+ ('Sun hat', 5475),
134
+ ('Bench', 5310),
135
+ ('Bookcase', 5163),
136
+ ('Bee', 5140),
137
+ ('Computer monitor', 5078),
138
+ ('Hiking equipment', 4983),
139
+ ('Office building', 4981),
140
+ ('Coffee cup', 4748),
141
+ ('Curtain', 4685),
142
+ ('Plate', 4651),
143
+ ('Box', 4621),
144
+ ('Tomato', 4595),
145
+ ('Coffee table', 4529),
146
+ ('Office supplies', 4473),
147
+ ('Maple', 4416),
148
+ ('Muffin', 4365),
149
+ ('Cocktail', 4234),
150
+ ('Castle', 4197),
151
+ ('Couch', 4134),
152
+ ('Pumpkin', 3983),
153
+ ('Computer keyboard', 3960),
154
+ ('Human mouth', 3926),
155
+ ('Christmas tree', 3893),
156
+ ('Mushroom', 3883),
157
+ ('Swimming pool', 3809),
158
+ ('Pastry', 3799),
159
+ ('Lavender (Plant)', 3769),
160
+ ('Football helmet', 3732),
161
+ ('Bread', 3648),
162
+ ('Traffic sign', 3628),
163
+ ('Common sunflower', 3597),
164
+ ('Television', 3550),
165
+ ('Bed', 3525),
166
+ ('Cookie', 3485),
167
+ ('Fountain', 3484),
168
+ ('Paddle', 3447),
169
+ ('Bicycle helmet', 3429),
170
+ ('Porch', 3420),
171
+ ('Deer', 3387),
172
+ ('Fedora', 3339),
173
+ ('Canoe', 3338),
174
+ ('Carnivore', 3266),
175
+ ('Bowl', 3202),
176
+ ('Human eye', 3166),
177
+ ('Ball', 3118),
178
+ ('Pillow', 3077),
179
+ ('Salad', 3061),
180
+ ('Beetle', 3060),
181
+ ('Orange', 3050),
182
+ ('Drawer', 2958),
183
+ ('Platter', 2937),
184
+ ('Elephant', 2921),
185
+ ('Seafood', 2921),
186
+ ('Monkey', 2915),
187
+ ('Countertop', 2879),
188
+ ('Watercraft', 2831),
189
+ ('Helicopter', 2805),
190
+ ('Kitchen appliance', 2797),
191
+ ('Personal flotation device', 2781),
192
+ ('Swan', 2739),
193
+ ('Lamp', 2711),
194
+ ('Boot', 2695),
195
+ ('Bronze sculpture', 2693),
196
+ ('Chicken', 2677),
197
+ ('Taxi', 2643),
198
+ ('Juice', 2615),
199
+ ('Cowboy hat', 2604),
200
+ ('Apple', 2600),
201
+ ('Tin can', 2590),
202
+ ('Necklace', 2564),
203
+ ('Ice cream', 2560),
204
+ ('Human beard', 2539),
205
+ ('Coin', 2536),
206
+ ('Candle', 2515),
207
+ ('Cart', 2512),
208
+ ('High heels', 2441),
209
+ ('Weapon', 2433),
210
+ ('Handbag', 2406),
211
+ ('Penguin', 2396),
212
+ ('Rifle', 2352),
213
+ ('Violin', 2336),
214
+ ('Skull', 2304),
215
+ ('Lantern', 2285),
216
+ ('Scarf', 2269),
217
+ ('Saucer', 2225),
218
+ ('Sheep', 2215),
219
+ ('Vase', 2189),
220
+ ('Lily', 2180),
221
+ ('Mug', 2154),
222
+ ('Parrot', 2140),
223
+ ('Human ear', 2137),
224
+ ('Sandal', 2115),
225
+ ('Lizard', 2100),
226
+ ('Kitchen & dining room table', 2063),
227
+ ('Spider', 1977),
228
+ ('Coffee', 1974),
229
+ ('Goat', 1926),
230
+ ('Squirrel', 1922),
231
+ ('Cello', 1913),
232
+ ('Sushi', 1881),
233
+ ('Tortoise', 1876),
234
+ ('Pizza', 1870),
235
+ ('Studio couch', 1864),
236
+ ('Barrel', 1862),
237
+ ('Cosmetics', 1841),
238
+ ('Moths and butterflies', 1841),
239
+ ('Convenience store', 1817),
240
+ ('Watch', 1792),
241
+ ('Home appliance', 1786),
242
+ ('Harbor seal', 1780),
243
+ ('Luggage and bags', 1756),
244
+ ('Vehicle registration plate', 1754),
245
+ ('Shrimp', 1751),
246
+ ('Jellyfish', 1730),
247
+ ('French fries', 1723),
248
+ ('Egg (Food)', 1698),
249
+ ('Football', 1697),
250
+ ('Musical keyboard', 1683),
251
+ ('Falcon', 1674),
252
+ ('Candy', 1660),
253
+ ('Medical equipment', 1654),
254
+ ('Eagle', 1651),
255
+ ('Dinosaur', 1634),
256
+ ('Surfboard', 1630),
257
+ ('Tank', 1628),
258
+ ('Grape', 1624),
259
+ ('Lion', 1624),
260
+ ('Owl', 1622),
261
+ ('Ski', 1613),
262
+ ('Waste container', 1606),
263
+ ('Frog', 1591),
264
+ ('Sparrow', 1585),
265
+ ('Rabbit', 1581),
266
+ ('Pen', 1546),
267
+ ('Sea lion', 1537),
268
+ ('Spoon', 1521),
269
+ ('Sink', 1512),
270
+ ('Teddy bear', 1507),
271
+ ('Bull', 1495),
272
+ ('Sofa bed', 1490),
273
+ ('Dragonfly', 1479),
274
+ ('Brassiere', 1478),
275
+ ('Chest of drawers', 1472),
276
+ ('Aircraft', 1466),
277
+ ('Human foot', 1463),
278
+ ('Pig', 1455),
279
+ ('Fork', 1454),
280
+ ('Antelope', 1438),
281
+ ('Tripod', 1427),
282
+ ('Tool', 1424),
283
+ ('Cheese', 1422),
284
+ ('Lemon', 1397),
285
+ ('Hamburger', 1393),
286
+ ('Dolphin', 1390),
287
+ ('Mirror', 1390),
288
+ ('Marine mammal', 1387),
289
+ ('Giraffe', 1385),
290
+ ('Snake', 1368),
291
+ ('Gondola', 1364),
292
+ ('Wheelchair', 1360),
293
+ ('Piano', 1358),
294
+ ('Cupboard', 1348),
295
+ ('Banana', 1345),
296
+ ('Trumpet', 1335),
297
+ ('Lighthouse', 1333),
298
+ ('Invertebrate', 1317),
299
+ ('Carrot', 1268),
300
+ ('Sock', 1260),
301
+ ('Tiger', 1241),
302
+ ('Camel', 1224),
303
+ ('Parachute', 1224),
304
+ ('Bathroom accessory', 1223),
305
+ ('Earrings', 1221),
306
+ ('Headphones', 1218),
307
+ ('Skirt', 1198),
308
+ ('Skateboard', 1190),
309
+ ('Sandwich', 1148),
310
+ ('Saxophone', 1141),
311
+ ('Goldfish', 1136),
312
+ ('Stool', 1104),
313
+ ('Traffic light', 1097),
314
+ ('Shellfish', 1081),
315
+ ('Backpack', 1079),
316
+ ('Sea turtle', 1078),
317
+ ('Cucumber', 1075),
318
+ ('Tea', 1051),
319
+ ('Toilet', 1047),
320
+ ('Roller skates', 1040),
321
+ ('Mule', 1039),
322
+ ('Bust', 1031),
323
+ ('Broccoli', 1030),
324
+ ('Crab', 1020),
325
+ ('Oyster', 1019),
326
+ ('Cannon', 1012),
327
+ ('Zebra', 1012),
328
+ ('French horn', 1008),
329
+ ('Grapefruit', 998),
330
+ ('Whiteboard', 997),
331
+ ('Zucchini', 997),
332
+ ('Crocodile', 992),
333
+
334
+ ('Clock', 960),
335
+ ('Wall clock', 958),
336
+
337
+ ('Doughnut', 869),
338
+ ('Snail', 868),
339
+
340
+ ('Baseball glove', 859),
341
+
342
+ ('Panda', 830),
343
+ ('Tennis racket', 830),
344
+
345
+ ('Pear', 652),
346
+
347
+ ('Bagel', 617),
348
+ ('Oven', 616),
349
+ ('Ladybug', 615),
350
+ ('Shark', 615),
351
+ ('Polar bear', 614),
352
+ ('Ostrich', 609),
353
+
354
+ ('Hot dog', 473),
355
+ ('Microwave oven', 467),
356
+ ('Fire hydrant', 20),
357
+ ('Stop sign', 20),
358
+ ('Parking meter', 20),
359
+ ('Bear', 20),
360
+ ('Flying disc', 20),
361
+ ('Snowboard', 20),
362
+ ('Tennis ball', 20),
363
+ ('Kite', 20),
364
+ ('Baseball bat', 20),
365
+ ('Kitchen knife', 20),
366
+ ('Knife', 20),
367
+ ('Submarine sandwich', 20),
368
+ ('Computer mouse', 20),
369
+ ('Remote control', 20),
370
+ ('Toaster', 20),
371
+ ('Sink', 20),
372
+ ('Refrigerator', 20),
373
+ ('Alarm clock', 20),
374
+ ('Wall clock', 20),
375
+ ('Scissors', 20),
376
+ ('Hair dryer', 20),
377
+ ('Toothbrush', 20),
378
+ ('Suitcase', 20)
379
+ ]
taming/data/sflckr.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import numpy as np
+ import cv2
+ import albumentations
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ class SegmentationBase(Dataset):
+     def __init__(self,
+                  data_csv, data_root, segmentation_root,
+                  size=None, random_crop=False, interpolation="bicubic",
+                  n_labels=182, shift_segmentation=False,
+                  ):
+         self.n_labels = n_labels
+         self.shift_segmentation = shift_segmentation
+         self.data_csv = data_csv
+         self.data_root = data_root
+         self.segmentation_root = segmentation_root
+         with open(self.data_csv, "r") as f:
+             self.image_paths = f.read().splitlines()
+         self._length = len(self.image_paths)
+         self.labels = {
+             "relative_file_path_": [l for l in self.image_paths],
+             "file_path_": [os.path.join(self.data_root, l)
+                            for l in self.image_paths],
+             "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+                                    for l in self.image_paths]
+         }
+
+         size = None if size is not None and size<=0 else size
+         self.size = size
+         if self.size is not None:
+             self.interpolation = interpolation
+             self.interpolation = {
+                 "nearest": cv2.INTER_NEAREST,
+                 "bilinear": cv2.INTER_LINEAR,
+                 "bicubic": cv2.INTER_CUBIC,
+                 "area": cv2.INTER_AREA,
+                 "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+             self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+                                                                  interpolation=self.interpolation)
+             self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+                                                                         interpolation=cv2.INTER_NEAREST)
+             self.center_crop = not random_crop
+             if self.center_crop:
+                 self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+             else:
+                 self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+             self.preprocessor = self.cropper
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, i):
+         example = dict((k, self.labels[k][i]) for k in self.labels)
+         image = Image.open(example["file_path_"])
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+         image = np.array(image).astype(np.uint8)
+         if self.size is not None:
+             image = self.image_rescaler(image=image)["image"]
+         segmentation = Image.open(example["segmentation_path_"])
+         assert segmentation.mode == "L", segmentation.mode
+         segmentation = np.array(segmentation).astype(np.uint8)
+         if self.shift_segmentation:
+             # used to support segmentations containing unlabeled==255 label
+             segmentation = segmentation+1
+         if self.size is not None:
+             segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+         if self.size is not None:
+             processed = self.preprocessor(image=image,
+                                           mask=segmentation
+                                           )
+         else:
+             processed = {"image": image,
+                          "mask": segmentation
+                          }
+         example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+         segmentation = processed["mask"]
+         onehot = np.eye(self.n_labels)[segmentation]
+         example["segmentation"] = onehot
+         return example
+
+
+ class Examples(SegmentationBase):
+     def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+         super().__init__(data_csv="data/sflckr_examples.txt",
+                          data_root="data/sflckr_images",
+                          segmentation_root="data/sflckr_segmentations",
+                          size=size, random_crop=random_crop, interpolation=interpolation)
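For orientation only (not part of the uploaded file): a short usage sketch of the example dataset above, assuming the data/sflckr_* files shipped with the repository are present.

    from taming.data.sflckr import Examples

    ds = Examples(size=256)
    ex = ds[0]
    # ex["image"]: HxWx3 float32 in [-1, 1]; ex["segmentation"]: one-hot map with 182 classes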
taming/data/utils.py ADDED
@@ -0,0 +1,169 @@
+ import collections
+ import os
+ import tarfile
+ import urllib
+ import zipfile
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from taming.data.helper_types import Annotation
+ from torch._six import string_classes
+ from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+ from tqdm import tqdm
+
+
+ def unpack(path):
+     if path.endswith("tar.gz"):
+         with tarfile.open(path, "r:gz") as tar:
+             tar.extractall(path=os.path.split(path)[0])
+     elif path.endswith("tar"):
+         with tarfile.open(path, "r:") as tar:
+             tar.extractall(path=os.path.split(path)[0])
+     elif path.endswith("zip"):
+         with zipfile.ZipFile(path, "r") as f:
+             f.extractall(path=os.path.split(path)[0])
+     else:
+         raise NotImplementedError(
+             "Unknown file extension: {}".format(os.path.splitext(path)[1])
+         )
+
+
+ def reporthook(bar):
+     """tqdm progress bar for downloads."""
+
+     def hook(b=1, bsize=1, tsize=None):
+         if tsize is not None:
+             bar.total = tsize
+         bar.update(b * bsize - bar.n)
+
+     return hook
+
+
+ def get_root(name):
+     base = "data/"
+     root = os.path.join(base, name)
+     os.makedirs(root, exist_ok=True)
+     return root
+
+
+ def is_prepared(root):
+     return Path(root).joinpath(".ready").exists()
+
+
+ def mark_prepared(root):
+     Path(root).joinpath(".ready").touch()
+
+
+ def prompt_download(file_, source, target_dir, content_dir=None):
+     targetpath = os.path.join(target_dir, file_)
+     while not os.path.exists(targetpath):
+         if content_dir is not None and os.path.exists(
+             os.path.join(target_dir, content_dir)
+         ):
+             break
+         print(
+             "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+         )
+         if content_dir is not None:
+             print(
+                 "Or place its content into '{}'.".format(
+                     os.path.join(target_dir, content_dir)
+                 )
+             )
+         input("Press Enter when done...")
+     return targetpath
+
+
+ def download_url(file_, url, target_dir):
+     targetpath = os.path.join(target_dir, file_)
+     os.makedirs(target_dir, exist_ok=True)
+     with tqdm(
+         unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+     ) as bar:
+         urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+     return targetpath
+
+
+ def download_urls(urls, target_dir):
+     paths = dict()
+     for fname, url in urls.items():
+         outpath = download_url(fname, url, target_dir)
+         paths[fname] = outpath
+     return paths
+
+
+ def quadratic_crop(x, bbox, alpha=1.0):
+     """bbox is xmin, ymin, xmax, ymax"""
+     im_h, im_w = x.shape[:2]
+     bbox = np.array(bbox, dtype=np.float32)
+     bbox = np.clip(bbox, 0, max(im_h, im_w))
+     center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+     w = bbox[2] - bbox[0]
+     h = bbox[3] - bbox[1]
+     l = int(alpha * max(w, h))
+     l = max(l, 2)
+
+     required_padding = -1 * min(
+         center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+     )
+     required_padding = int(np.ceil(required_padding))
+     if required_padding > 0:
+         padding = [
+             [required_padding, required_padding],
+             [required_padding, required_padding],
+         ]
+         padding += [[0, 0]] * (len(x.shape) - 2)
+         x = np.pad(x, padding, "reflect")
+         center = center[0] + required_padding, center[1] + required_padding
+     xmin = int(center[0] - l / 2)
+     ymin = int(center[1] - l / 2)
+     return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+ def custom_collate(batch):
+     r"""source: pytorch 1.9.0, only one modification to original code """
+
+     elem = batch[0]
+     elem_type = type(elem)
+     if isinstance(elem, torch.Tensor):
+         out = None
+         if torch.utils.data.get_worker_info() is not None:
+             # If we're in a background process, concatenate directly into a
+             # shared memory tensor to avoid an extra copy
+             numel = sum([x.numel() for x in batch])
+             storage = elem.storage()._new_shared(numel)
+             out = elem.new(storage)
+         return torch.stack(batch, 0, out=out)
+     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+             and elem_type.__name__ != 'string_':
+         if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+             # array of string classes and object
+             if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+             return custom_collate([torch.as_tensor(b) for b in batch])
+         elif elem.shape == (): # scalars
+             return torch.as_tensor(batch)
+     elif isinstance(elem, float):
+         return torch.tensor(batch, dtype=torch.float64)
+     elif isinstance(elem, int):
+         return torch.tensor(batch)
+     elif isinstance(elem, string_classes):
+         return batch
+     elif isinstance(elem, collections.abc.Mapping):
+         return {key: custom_collate([d[key] for d in batch]) for key in elem}
+     elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+         return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+     if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+         return batch # added
+     elif isinstance(elem, collections.abc.Sequence):
+         # check to make sure that the elements in batch have consistent size
+         it = iter(batch)
+         elem_size = len(next(it))
+         if not all(len(elem) == elem_size for elem in it):
+             raise RuntimeError('each element in list of batch should be of equal size')
+         transposed = zip(*batch)
+         return [custom_collate(samples) for samples in transposed]
+
+     raise TypeError(default_collate_err_msg_format.format(elem_type))
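For orientation only (not part of the uploaded file): custom_collate diverges from the PyTorch default only for lists of Annotation tuples, which it passes through unbatched. A hedged sketch of wiring it into a DataLoader, assuming the older PyTorch version this module targets (it imports torch._six).

    import torch
    from torch.utils.data import DataLoader
    from taming.data.helper_types import Annotation
    from taming.data.utils import custom_collate

    # Toy samples: a tensor plus a list of Annotation tuples per example.
    ann = Annotation(area=1.0, image_id="0", bbox=(0.0, 0.0, 0.5, 0.5),
                     category_no=0, category_id="/m/01g317")
    samples = [{"image": torch.zeros(3, 4, 4), "annotations": [ann]} for _ in range(4)]

    loader = DataLoader(samples, batch_size=2, collate_fn=custom_collate)
    batch = next(iter(loader))
    # batch["image"] is stacked into a tensor; batch["annotations"] stays a list of Annotation lists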
taming/lr_scheduler.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+
+
+ class LambdaWarmUpCosineScheduler:
+     """
+     note: use with a base_lr of 1.0
+     """
+     def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+         self.lr_warm_up_steps = warm_up_steps
+         self.lr_start = lr_start
+         self.lr_min = lr_min
+         self.lr_max = lr_max
+         self.lr_max_decay_steps = max_decay_steps
+         self.last_lr = 0.
+         self.verbosity_interval = verbosity_interval
+
+     def schedule(self, n):
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+         if n < self.lr_warm_up_steps:
+             lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+             self.last_lr = lr
+             return lr
+         else:
+             t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+             t = min(t, 1.0)
+             lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+                     1 + np.cos(t * np.pi))
+             self.last_lr = lr
+             return lr
+
+     def __call__(self, n):
+         return self.schedule(n)
+
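For orientation only (not part of the uploaded file): the scheduler returns a learning-rate multiplier, so it is typically wrapped in torch.optim.lr_scheduler.LambdaLR; a hedged sketch with placeholder hyperparameters.

    import torch
    from taming.lr_scheduler import LambdaWarmUpCosineScheduler

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters(), lr=4.5e-6)
    sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=10_000, lr_min=0.0, lr_max=1.0,
                                           lr_start=0.0, max_decay_steps=500_000)
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)
    # call opt.step() followed by scheduler.step() once per training step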
taming/models/cond_transformer.py ADDED
@@ -0,0 +1,352 @@
1
+ import os, math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ from main import instantiate_from_config
7
+ from taming.modules.util import SOSProvider
8
+
9
+
10
+ def disabled_train(self, mode=True):
11
+ """Overwrite model.train with this function to make sure train/eval mode
12
+ does not change anymore."""
13
+ return self
14
+
15
+
16
+ class Net2NetTransformer(pl.LightningModule):
17
+ def __init__(self,
18
+ transformer_config,
19
+ first_stage_config,
20
+ cond_stage_config,
21
+ permuter_config=None,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ first_stage_key="image",
25
+ cond_stage_key="depth",
26
+ downsample_cond_size=-1,
27
+ pkeep=1.0,
28
+ sos_token=0,
29
+ unconditional=False,
30
+ ):
31
+ super().__init__()
32
+ self.be_unconditional = unconditional
33
+ self.sos_token = sos_token
34
+ self.first_stage_key = first_stage_key
35
+ self.cond_stage_key = cond_stage_key
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if permuter_config is None:
39
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
+ self.permuter = instantiate_from_config(config=permuter_config)
41
+ self.transformer = instantiate_from_config(config=transformer_config)
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+ self.downsample_cond_size = downsample_cond_size
46
+ self.pkeep = pkeep
47
+
48
+ def init_from_ckpt(self, path, ignore_keys=list()):
49
+ sd = torch.load(path, map_location="cpu")["state_dict"]
50
+ for k in sd.keys():
51
+ for ik in ignore_keys:
52
+ if k.startswith(ik):
53
+ self.print("Deleting key {} from state_dict.".format(k))
54
+ del sd[k]
55
+ self.load_state_dict(sd, strict=False)
56
+ print(f"Restored from {path}")
57
+
58
+ def init_first_stage_from_ckpt(self, config):
59
+ model = instantiate_from_config(config)
60
+ model = model.eval()
61
+ model.train = disabled_train
62
+ self.first_stage_model = model
63
+
64
+ def init_cond_stage_from_ckpt(self, config):
65
+ if config == "__is_first_stage__":
66
+ print("Using first stage also as cond stage.")
67
+ self.cond_stage_model = self.first_stage_model
68
+ elif config == "__is_unconditional__" or self.be_unconditional:
69
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
+ f"Prepending {self.sos_token} as a sos token.")
71
+ self.be_unconditional = True
72
+ self.cond_stage_key = self.first_stage_key
73
+ self.cond_stage_model = SOSProvider(self.sos_token)
74
+ else:
75
+ model = instantiate_from_config(config)
76
+ model = model.eval()
77
+ model.train = disabled_train
78
+ self.cond_stage_model = model
79
+
80
+ def forward(self, x, c):
81
+ # one step to produce the logits
82
+ _, z_indices = self.encode_to_z(x)
83
+ _, c_indices = self.encode_to_c(c)
84
+
85
+ if self.training and self.pkeep < 1.0:
86
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
87
+ device=z_indices.device))
88
+ mask = mask.round().to(dtype=torch.int64)
89
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
90
+ a_indices = mask*z_indices+(1-mask)*r_indices
91
+ else:
92
+ a_indices = z_indices
93
+
94
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
95
+
96
+ # target includes all sequence elements (no need to handle first one
97
+ # differently because we are conditioning)
98
+ target = z_indices
99
+ # make the prediction
100
+ logits, _ = self.transformer(cz_indices[:, :-1])
101
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
102
+ logits = logits[:, c_indices.shape[1]-1:]
103
+
104
+ return logits, target
105
+
106
+ def top_k_logits(self, logits, k):
107
+ v, ix = torch.topk(logits, k)
108
+ out = logits.clone()
109
+ out[out < v[..., [-1]]] = -float('Inf')
110
+ return out
111
+
112
+ @torch.no_grad()
113
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
114
+ callback=lambda k: None):
115
+ x = torch.cat((c,x),dim=1)
116
+ block_size = self.transformer.get_block_size()
117
+ assert not self.transformer.training
118
+ if self.pkeep <= 0.0:
119
+ # one pass suffices since input is pure noise anyway
120
+ assert len(x.shape)==2
121
+ noise_shape = (x.shape[0], steps-1)
122
+ #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
123
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
124
+ x = torch.cat((x,noise),dim=1)
125
+ logits, _ = self.transformer(x)
126
+ # take all logits for now and scale by temp
127
+ logits = logits / temperature
128
+ # optionally crop probabilities to only the top k options
129
+ if top_k is not None:
130
+ logits = self.top_k_logits(logits, top_k)
131
+ # apply softmax to convert to probabilities
132
+ probs = F.softmax(logits, dim=-1)
133
+ # sample from the distribution or take the most likely
134
+ if sample:
135
+ shape = probs.shape
136
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
137
+ ix = torch.multinomial(probs, num_samples=1)
138
+ probs = probs.reshape(shape[0],shape[1],shape[2])
139
+ ix = ix.reshape(shape[0],shape[1])
140
+ else:
141
+ _, ix = torch.topk(probs, k=1, dim=-1)
142
+ # cut off conditioning
143
+ x = ix[:, c.shape[1]-1:]
144
+ else:
145
+ for k in range(steps):
146
+ callback(k)
147
+ assert x.size(1) <= block_size # make sure model can see conditioning
148
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
149
+ logits, _ = self.transformer(x_cond)
150
+ # pluck the logits at the final step and scale by temperature
151
+ logits = logits[:, -1, :] / temperature
152
+ # optionally crop probabilities to only the top k options
153
+ if top_k is not None:
154
+ logits = self.top_k_logits(logits, top_k)
155
+ # apply softmax to convert to probabilities
156
+ probs = F.softmax(logits, dim=-1)
157
+ # sample from the distribution or take the most likely
158
+ if sample:
159
+ ix = torch.multinomial(probs, num_samples=1)
160
+ else:
161
+ _, ix = torch.topk(probs, k=1, dim=-1)
162
+ # append to the sequence and continue
163
+ x = torch.cat((x, ix), dim=1)
164
+ # cut off conditioning
165
+ x = x[:, c.shape[1]:]
166
+ return x
167
+
168
+ @torch.no_grad()
169
+ def encode_to_z(self, x):
170
+ quant_z, _, info = self.first_stage_model.encode(x)
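+ # the third element of info carries the codebook index of every spatial position (encode_to_c below unpacks it the same way)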
171
+ indices = info[2].view(quant_z.shape[0], -1)
172
+ indices = self.permuter(indices)
173
+ return quant_z, indices
174
+
175
+ @torch.no_grad()
176
+ def encode_to_c(self, c):
177
+ if self.downsample_cond_size > -1:
178
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
179
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
180
+ if len(indices.shape) > 2:
181
+ indices = indices.view(c.shape[0], -1)
182
+ return quant_c, indices
183
+
184
+ @torch.no_grad()
185
+ def decode_to_img(self, index, zshape):
186
+ index = self.permuter(index, reverse=True)
187
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
188
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
189
+ index.reshape(-1), shape=bhwc)
190
+ x = self.first_stage_model.decode(quant_z)
191
+ return x
192
+
193
+ @torch.no_grad()
194
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
195
+ log = dict()
196
+
197
+ N = 4
198
+ if lr_interface:
199
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
200
+ else:
201
+ x, c = self.get_xc(batch, N)
202
+ x = x.to(device=self.device)
203
+ c = c.to(device=self.device)
204
+
205
+ quant_z, z_indices = self.encode_to_z(x)
206
+ quant_c, c_indices = self.encode_to_c(c)
207
+
208
+ # create a "half" sample
209
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
210
+ index_sample = self.sample(z_start_indices, c_indices,
211
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
212
+ temperature=temperature if temperature is not None else 1.0,
213
+ sample=True,
214
+ top_k=top_k if top_k is not None else 100,
215
+ callback=callback if callback is not None else lambda k: None)
216
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
217
+
218
+ # sample
219
+ z_start_indices = z_indices[:, :0]
220
+ index_sample = self.sample(z_start_indices, c_indices,
221
+ steps=z_indices.shape[1],
222
+ temperature=temperature if temperature is not None else 1.0,
223
+ sample=True,
224
+ top_k=top_k if top_k is not None else 100,
225
+ callback=callback if callback is not None else lambda k: None)
226
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
227
+
228
+ # det sample
229
+ z_start_indices = z_indices[:, :0]
230
+ index_sample = self.sample(z_start_indices, c_indices,
231
+ steps=z_indices.shape[1],
232
+ sample=False,
233
+ callback=callback if callback is not None else lambda k: None)
234
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
235
+
236
+ # reconstruction
237
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
238
+
239
+ log["inputs"] = x
240
+ log["reconstructions"] = x_rec
241
+
242
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
243
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
244
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
245
+ label_for_category_no = dataset.get_textual_label_for_category_no
246
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
247
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
248
+ for i in range(quant_c.shape[0]):
249
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
250
+ log["conditioning_rec"] = log["conditioning"]
251
+ elif self.cond_stage_key != "image":
252
+ cond_rec = self.cond_stage_model.decode(quant_c)
253
+ if self.cond_stage_key == "segmentation":
254
+ # get image from segmentation mask
255
+ num_classes = cond_rec.shape[1]
256
+
257
+ c = torch.argmax(c, dim=1, keepdim=True)
258
+ c = F.one_hot(c, num_classes=num_classes)
259
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
260
+ c = self.cond_stage_model.to_rgb(c)
261
+
262
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
263
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
264
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
265
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
266
+ log["conditioning_rec"] = cond_rec
267
+ log["conditioning"] = c
268
+
269
+ log["samples_half"] = x_sample
270
+ log["samples_nopix"] = x_sample_nopix
271
+ log["samples_det"] = x_sample_det
272
+ return log
273
+
274
+ def get_input(self, key, batch):
275
+ x = batch[key]
276
+ if len(x.shape) == 3:
277
+ x = x[..., None]
278
+ if len(x.shape) == 4:
279
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
280
+ if x.dtype == torch.double:
281
+ x = x.float()
282
+ return x
283
+
284
+ def get_xc(self, batch, N=None):
285
+ x = self.get_input(self.first_stage_key, batch)
286
+ c = self.get_input(self.cond_stage_key, batch)
287
+ if N is not None:
288
+ x = x[:N]
289
+ c = c[:N]
290
+ return x, c
291
+
292
+ def shared_step(self, batch, batch_idx):
293
+ x, c = self.get_xc(batch)
294
+ logits, target = self(x, c)
295
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
296
+ return loss
297
+
298
+ def training_step(self, batch, batch_idx):
299
+ loss = self.shared_step(batch, batch_idx)
300
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
301
+ return loss
302
+
303
+ def validation_step(self, batch, batch_idx):
304
+ loss = self.shared_step(batch, batch_idx)
305
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
306
+ return loss
307
+
308
+ def configure_optimizers(self):
309
+ """
310
+ Following minGPT:
311
+ This long function is unfortunately doing something very simple and is being very defensive:
312
+ We are separating out all parameters of the model into two buckets: those that will experience
313
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
314
+ We are then returning the PyTorch optimizer object.
315
+ """
316
+ # separate out all parameters to those that will and won't experience regularizing weight decay
317
+ decay = set()
318
+ no_decay = set()
319
+ whitelist_weight_modules = (torch.nn.Linear, )
320
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
321
+ for mn, m in self.transformer.named_modules():
322
+ for pn, p in m.named_parameters():
323
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
324
+
325
+ if pn.endswith('bias'):
326
+ # all biases will not be decayed
327
+ no_decay.add(fpn)
328
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
329
+ # weights of whitelist modules will be weight decayed
330
+ decay.add(fpn)
331
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
332
+ # weights of blacklist modules will NOT be weight decayed
333
+ no_decay.add(fpn)
334
+
335
+ # special case the position embedding parameter in the root GPT module as not decayed
336
+ no_decay.add('pos_emb')
337
+
338
+ # validate that we considered every parameter
339
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
340
+ inter_params = decay & no_decay
341
+ union_params = decay | no_decay
342
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
343
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
344
+ % (str(param_dict.keys() - union_params), )
345
+
346
+ # create the pytorch optimizer object
347
+ optim_groups = [
348
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
349
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
350
+ ]
351
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
352
+ return optimizer
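
For orientation only, a minimal sampling sketch assembled from the methods above (encode_to_z, encode_to_c, sample, decode_to_img); it is not part of the uploaded file, and `model`, `x` and `c` stand for an instantiated Net2NetTransformer plus an image/conditioning pair taken from a dataloader.

    import torch

    @torch.no_grad()
    def sample_from_conditioning(model, x, c, temperature=1.0, top_k=100):
        # mirrors the "sample" branch of log_images: encode, autoregressively
        # sample all latent positions, then decode the indices back to pixels
        quant_z, z_indices = model.encode_to_z(x)   # only needed for the latent shape
        _, c_indices = model.encode_to_c(c)
        z_start = z_indices[:, :0]                  # empty start sequence
        index_sample = model.sample(z_start, c_indices,
                                    steps=z_indices.shape[1],
                                    temperature=temperature,
                                    sample=True, top_k=top_k)
        return model.decode_to_img(index_sample, quant_z.shape)
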
taming/models/dummy_cond_stage.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class DummyCondStage:
5
+ def __init__(self, conditional_key):
6
+ self.conditional_key = conditional_key
7
+ self.train = None
8
+
9
+ def eval(self):
10
+ return self
11
+
12
+ @staticmethod
13
+ def encode(c: Tensor):
14
+ return c, None, (None, None, c)
15
+
16
+ @staticmethod
17
+ def decode(c: Tensor):
18
+ return c
19
+
20
+ @staticmethod
21
+ def to_rgb(c: Tensor):
22
+ return c
taming/models/vqgan.py ADDED
@@ -0,0 +1,404 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import pytorch_lightning as pl
4
+
5
+ from main import instantiate_from_config
6
+
7
+ from taming.modules.diffusionmodules.model import Encoder, Decoder
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from taming.modules.vqvae.quantize import GumbelQuantize
10
+ from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
+
12
+ class VQModel(pl.LightningModule):
13
+ def __init__(self,
14
+ ddconfig,
15
+ lossconfig,
16
+ n_embed,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ remap=None,
24
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
25
+ ):
26
+ super().__init__()
27
+ self.image_key = image_key
28
+ self.encoder = Encoder(**ddconfig)
29
+ self.decoder = Decoder(**ddconfig)
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
32
+ remap=remap, sane_index_shape=sane_index_shape)
33
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ if ckpt_path is not None:
36
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
37
+ self.image_key = image_key
38
+ if colorize_nlabels is not None:
39
+ assert type(colorize_nlabels)==int
40
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
41
+ if monitor is not None:
42
+ self.monitor = monitor
43
+
44
+ def init_from_ckpt(self, path, ignore_keys=list()):
45
+ sd = torch.load(path, map_location="cpu")["state_dict"]
46
+ keys = list(sd.keys())
47
+ for k in keys:
48
+ for ik in ignore_keys:
49
+ if k.startswith(ik):
50
+ print("Deleting key {} from state_dict.".format(k))
51
+ del sd[k]
52
+ self.load_state_dict(sd, strict=False)
53
+ print(f"Restored from {path}")
54
+
55
+ def encode(self, x):
56
+ h = self.encoder(x)
57
+ h = self.quant_conv(h)
58
+ quant, emb_loss, info = self.quantize(h)
59
+ return quant, emb_loss, info
60
+
61
+ def decode(self, quant):
62
+ quant = self.post_quant_conv(quant)
63
+ dec = self.decoder(quant)
64
+ return dec
65
+
66
+ def decode_code(self, code_b):
67
+ quant_b = self.quantize.embed_code(code_b)
68
+ dec = self.decode(quant_b)
69
+ return dec
70
+
71
+ def forward(self, input):
72
+ quant, diff, _ = self.encode(input)
73
+ dec = self.decode(quant)
74
+ return dec, diff
75
+
76
+ def get_input(self, batch, k):
77
+ x = batch[k]
78
+ if len(x.shape) == 3:
79
+ x = x[..., None]
80
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
81
+ return x.float()
82
+
83
+ def training_step(self, batch, batch_idx, optimizer_idx):
84
+ x = self.get_input(batch, self.image_key)
85
+ xrec, qloss = self(x)
86
+
87
+ if optimizer_idx == 0:
88
+ # autoencode
89
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
90
+ last_layer=self.get_last_layer(), split="train")
91
+
92
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
93
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
94
+ return aeloss
95
+
96
+ if optimizer_idx == 1:
97
+ # discriminator
98
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
99
+ last_layer=self.get_last_layer(), split="train")
100
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
101
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
102
+ return discloss
103
+
104
+ def validation_step(self, batch, batch_idx):
105
+ x = self.get_input(batch, self.image_key)
106
+ xrec, qloss = self(x)
107
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
108
+ last_layer=self.get_last_layer(), split="val")
109
+
110
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
111
+ last_layer=self.get_last_layer(), split="val")
112
+ rec_loss = log_dict_ae["val/rec_loss"]
113
+ self.log("val/rec_loss", rec_loss,
114
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
115
+ self.log("val/aeloss", aeloss,
116
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
117
+ self.log_dict(log_dict_ae)
118
+ self.log_dict(log_dict_disc)
119
+ return self.log_dict
120
+
121
+ def configure_optimizers(self):
122
+ lr = self.learning_rate
123
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
124
+ list(self.decoder.parameters())+
125
+ list(self.quantize.parameters())+
126
+ list(self.quant_conv.parameters())+
127
+ list(self.post_quant_conv.parameters()),
128
+ lr=lr, betas=(0.5, 0.9))
129
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
130
+ lr=lr, betas=(0.5, 0.9))
131
+ return [opt_ae, opt_disc], []
132
+
133
+ def get_last_layer(self):
134
+ return self.decoder.conv_out.weight
135
+
136
+ def log_images(self, batch, **kwargs):
137
+ log = dict()
138
+ x = self.get_input(batch, self.image_key)
139
+ x = x.to(self.device)
140
+ xrec, _ = self(x)
141
+ if x.shape[1] > 3:
142
+ # colorize with random projection
143
+ assert xrec.shape[1] > 3
144
+ x = self.to_rgb(x)
145
+ xrec = self.to_rgb(xrec)
146
+ log["inputs"] = x
147
+ log["reconstructions"] = xrec
148
+ return log
149
+
150
+ def to_rgb(self, x):
151
+ assert self.image_key == "segmentation"
152
+ if not hasattr(self, "colorize"):
153
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
154
+ x = F.conv2d(x, weight=self.colorize)
155
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
156
+ return x
157
+
158
+
159
+ class VQSegmentationModel(VQModel):
160
+ def __init__(self, n_labels, *args, **kwargs):
161
+ super().__init__(*args, **kwargs)
162
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
163
+
164
+ def configure_optimizers(self):
165
+ lr = self.learning_rate
166
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
167
+ list(self.decoder.parameters())+
168
+ list(self.quantize.parameters())+
169
+ list(self.quant_conv.parameters())+
170
+ list(self.post_quant_conv.parameters()),
171
+ lr=lr, betas=(0.5, 0.9))
172
+ return opt_ae
173
+
174
+ def training_step(self, batch, batch_idx):
175
+ x = self.get_input(batch, self.image_key)
176
+ xrec, qloss = self(x)
177
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
178
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
179
+ return aeloss
180
+
181
+ def validation_step(self, batch, batch_idx):
182
+ x = self.get_input(batch, self.image_key)
183
+ xrec, qloss = self(x)
184
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
185
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
186
+ total_loss = log_dict_ae["val/total_loss"]
187
+ self.log("val/total_loss", total_loss,
188
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
189
+ return aeloss
190
+
191
+ @torch.no_grad()
192
+ def log_images(self, batch, **kwargs):
193
+ log = dict()
194
+ x = self.get_input(batch, self.image_key)
195
+ x = x.to(self.device)
196
+ xrec, _ = self(x)
197
+ if x.shape[1] > 3:
198
+ # colorize with random projection
199
+ assert xrec.shape[1] > 3
200
+ # convert logits to indices
201
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
202
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
203
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
204
+ x = self.to_rgb(x)
205
+ xrec = self.to_rgb(xrec)
206
+ log["inputs"] = x
207
+ log["reconstructions"] = xrec
208
+ return log
209
+
210
+
211
+ class VQNoDiscModel(VQModel):
212
+ def __init__(self,
213
+ ddconfig,
214
+ lossconfig,
215
+ n_embed,
216
+ embed_dim,
217
+ ckpt_path=None,
218
+ ignore_keys=[],
219
+ image_key="image",
220
+ colorize_nlabels=None
221
+ ):
222
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
223
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
224
+ colorize_nlabels=colorize_nlabels)
225
+
226
+ def training_step(self, batch, batch_idx):
227
+ x = self.get_input(batch, self.image_key)
228
+ xrec, qloss = self(x)
229
+ # autoencode
230
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
231
+ output = pl.TrainResult(minimize=aeloss)
232
+ output.log("train/aeloss", aeloss,
233
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
234
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
235
+ return output
236
+
237
+ def validation_step(self, batch, batch_idx):
238
+ x = self.get_input(batch, self.image_key)
239
+ xrec, qloss = self(x)
240
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
241
+ rec_loss = log_dict_ae["val/rec_loss"]
242
+ output = pl.EvalResult(checkpoint_on=rec_loss)
243
+ output.log("val/rec_loss", rec_loss,
244
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
245
+ output.log("val/aeloss", aeloss,
246
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
247
+ output.log_dict(log_dict_ae)
248
+
249
+ return output
250
+
251
+ def configure_optimizers(self):
252
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
253
+ list(self.decoder.parameters())+
254
+ list(self.quantize.parameters())+
255
+ list(self.quant_conv.parameters())+
256
+ list(self.post_quant_conv.parameters()),
257
+ lr=self.learning_rate, betas=(0.5, 0.9))
258
+ return optimizer
259
+
260
+
261
+ class GumbelVQ(VQModel):
262
+ def __init__(self,
263
+ ddconfig,
264
+ lossconfig,
265
+ n_embed,
266
+ embed_dim,
267
+ temperature_scheduler_config,
268
+ ckpt_path=None,
269
+ ignore_keys=[],
270
+ image_key="image",
271
+ colorize_nlabels=None,
272
+ monitor=None,
273
+ kl_weight=1e-8,
274
+ remap=None,
275
+ ):
276
+
277
+ z_channels = ddconfig["z_channels"]
278
+ super().__init__(ddconfig,
279
+ lossconfig,
280
+ n_embed,
281
+ embed_dim,
282
+ ckpt_path=None,
283
+ ignore_keys=ignore_keys,
284
+ image_key=image_key,
285
+ colorize_nlabels=colorize_nlabels,
286
+ monitor=monitor,
287
+ )
288
+
289
+ self.loss.n_classes = n_embed
290
+ self.vocab_size = n_embed
291
+
292
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
293
+ n_embed=n_embed,
294
+ kl_weight=kl_weight, temp_init=1.0,
295
+ remap=remap)
296
+
297
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
298
+
299
+ if ckpt_path is not None:
300
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
301
+
302
+ def temperature_scheduling(self):
303
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
304
+
305
+ def encode_to_prequant(self, x):
306
+ h = self.encoder(x)
307
+ h = self.quant_conv(h)
308
+ return h
309
+
310
+ def decode_code(self, code_b):
311
+ raise NotImplementedError
312
+
313
+ def training_step(self, batch, batch_idx, optimizer_idx):
314
+ self.temperature_scheduling()
315
+ x = self.get_input(batch, self.image_key)
316
+ xrec, qloss = self(x)
317
+
318
+ if optimizer_idx == 0:
319
+ # autoencode
320
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
321
+ last_layer=self.get_last_layer(), split="train")
322
+
323
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
324
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
325
+ return aeloss
326
+
327
+ if optimizer_idx == 1:
328
+ # discriminator
329
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
330
+ last_layer=self.get_last_layer(), split="train")
331
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
332
+ return discloss
333
+
334
+ def validation_step(self, batch, batch_idx):
335
+ x = self.get_input(batch, self.image_key)
336
+ xrec, qloss = self(x)
337
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
338
+ last_layer=self.get_last_layer(), split="val")
339
+
340
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
341
+ last_layer=self.get_last_layer(), split="val")
342
+ rec_loss = log_dict_ae["val/rec_loss"]
343
+ self.log("val/rec_loss", rec_loss,
344
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
345
+ self.log("val/aeloss", aeloss,
346
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
347
+ self.log_dict(log_dict_ae)
348
+ self.log_dict(log_dict_disc)
349
+ return self.log_dict
350
+
351
+ def log_images(self, batch, **kwargs):
352
+ log = dict()
353
+ x = self.get_input(batch, self.image_key)
354
+ x = x.to(self.device)
355
+ # encode
356
+ h = self.encoder(x)
357
+ h = self.quant_conv(h)
358
+ quant, _, _ = self.quantize(h)
359
+ # decode
360
+ x_rec = self.decode(quant)
361
+ log["inputs"] = x
362
+ log["reconstructions"] = x_rec
363
+ return log
364
+
365
+
366
+ class EMAVQ(VQModel):
367
+ def __init__(self,
368
+ ddconfig,
369
+ lossconfig,
370
+ n_embed,
371
+ embed_dim,
372
+ ckpt_path=None,
373
+ ignore_keys=[],
374
+ image_key="image",
375
+ colorize_nlabels=None,
376
+ monitor=None,
377
+ remap=None,
378
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
379
+ ):
380
+ super().__init__(ddconfig,
381
+ lossconfig,
382
+ n_embed,
383
+ embed_dim,
384
+ ckpt_path=None,
385
+ ignore_keys=ignore_keys,
386
+ image_key=image_key,
387
+ colorize_nlabels=colorize_nlabels,
388
+ monitor=monitor,
389
+ )
390
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
391
+ embedding_dim=embed_dim,
392
+ beta=0.25,
393
+ remap=remap)
394
+ def configure_optimizers(self):
395
+ lr = self.learning_rate
396
+ #Remove self.quantize from parameter list since it is updated via EMA
397
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
398
+ list(self.decoder.parameters())+
399
+ list(self.quant_conv.parameters())+
400
+ list(self.post_quant_conv.parameters()),
401
+ lr=lr, betas=(0.5, 0.9))
402
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
403
+ lr=lr, betas=(0.5, 0.9))
404
+ return [opt_ae, opt_disc], []
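
A rough usage sketch for VQModel (not part of the upload): in practice the model is built from a YAML config through instantiate_from_config, so every literal below, including the DummyLoss target (assumed here to need no constructor arguments), is an illustrative placeholder; it also assumes the repository root with main.py is on the import path.

    import torch
    from taming.models.vqgan import VQModel

    ddconfig = dict(double_z=False, z_channels=256, resolution=256, in_channels=3,
                    out_ch=3, ch=128, ch_mult=[1, 1, 2, 2, 4], num_res_blocks=2,
                    attn_resolutions=[16], dropout=0.0)
    lossconfig = {"target": "taming.modules.losses.DummyLoss"}  # skip GAN/perceptual terms
    model = VQModel(ddconfig, lossconfig, n_embed=1024, embed_dim=256).eval()

    x = torch.randn(1, 3, 256, 256)                          # images expected in [-1, 1]
    with torch.no_grad():
        quant, emb_loss, (_, _, indices) = model.encode(x)   # quant: (1, 256, 16, 16)
        xrec = model.decode(quant)                           # xrec:  (1, 3, 256, 256)
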
taming/modules/__pycache__/util.cpython-38.pyc ADDED
Binary file (4.28 kB). View file
 
taming/modules/autoencoder/lpips/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
taming/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,776 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models
11
+ (adapted from Fairseq).
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ else:
111
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
+ out_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x, temb):
118
+ h = x
119
+ h = self.norm1(h)
120
+ h = nonlinearity(h)
121
+ h = self.conv1(h)
122
+
123
+ if temb is not None:
124
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
+
126
+ h = self.norm2(h)
127
+ h = nonlinearity(h)
128
+ h = self.dropout(h)
129
+ h = self.conv2(h)
130
+
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ x = self.conv_shortcut(x)
134
+ else:
135
+ x = self.nin_shortcut(x)
136
+
137
+ return x+h
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(self, in_channels):
142
+ super().__init__()
143
+ self.in_channels = in_channels
144
+
145
+ self.norm = Normalize(in_channels)
146
+ self.q = torch.nn.Conv2d(in_channels,
147
+ in_channels,
148
+ kernel_size=1,
149
+ stride=1,
150
+ padding=0)
151
+ self.k = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.v = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.proj_out = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b,c,h,w = q.shape
177
+ q = q.reshape(b,c,h*w)
178
+ q = q.permute(0,2,1) # b,hw,c
179
+ k = k.reshape(b,c,h*w) # b,c,hw
180
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c)**(-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b,c,h*w)
186
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b,c,h,w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x+h_
193
+
194
+
195
+ class Model(nn.Module):
196
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
+ resolution, use_timestep=True):
199
+ super().__init__()
200
+ self.ch = ch
201
+ self.temb_ch = self.ch*4
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.resolution = resolution
205
+ self.in_channels = in_channels
206
+
207
+ self.use_timestep = use_timestep
208
+ if self.use_timestep:
209
+ # timestep embedding
210
+ self.temb = nn.Module()
211
+ self.temb.dense = nn.ModuleList([
212
+ torch.nn.Linear(self.ch,
213
+ self.temb_ch),
214
+ torch.nn.Linear(self.temb_ch,
215
+ self.temb_ch),
216
+ ])
217
+
218
+ # downsampling
219
+ self.conv_in = torch.nn.Conv2d(in_channels,
220
+ self.ch,
221
+ kernel_size=3,
222
+ stride=1,
223
+ padding=1)
224
+
225
+ curr_res = resolution
226
+ in_ch_mult = (1,)+tuple(ch_mult)
227
+ self.down = nn.ModuleList()
228
+ for i_level in range(self.num_resolutions):
229
+ block = nn.ModuleList()
230
+ attn = nn.ModuleList()
231
+ block_in = ch*in_ch_mult[i_level]
232
+ block_out = ch*ch_mult[i_level]
233
+ for i_block in range(self.num_res_blocks):
234
+ block.append(ResnetBlock(in_channels=block_in,
235
+ out_channels=block_out,
236
+ temb_channels=self.temb_ch,
237
+ dropout=dropout))
238
+ block_in = block_out
239
+ if curr_res in attn_resolutions:
240
+ attn.append(AttnBlock(block_in))
241
+ down = nn.Module()
242
+ down.block = block
243
+ down.attn = attn
244
+ if i_level != self.num_resolutions-1:
245
+ down.downsample = Downsample(block_in, resamp_with_conv)
246
+ curr_res = curr_res // 2
247
+ self.down.append(down)
248
+
249
+ # middle
250
+ self.mid = nn.Module()
251
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
+ out_channels=block_in,
253
+ temb_channels=self.temb_ch,
254
+ dropout=dropout)
255
+ self.mid.attn_1 = AttnBlock(block_in)
256
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+
261
+ # upsampling
262
+ self.up = nn.ModuleList()
263
+ for i_level in reversed(range(self.num_resolutions)):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_out = ch*ch_mult[i_level]
267
+ skip_in = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks+1):
269
+ if i_block == self.num_res_blocks:
270
+ skip_in = ch*in_ch_mult[i_level]
271
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
272
+ out_channels=block_out,
273
+ temb_channels=self.temb_ch,
274
+ dropout=dropout))
275
+ block_in = block_out
276
+ if curr_res in attn_resolutions:
277
+ attn.append(AttnBlock(block_in))
278
+ up = nn.Module()
279
+ up.block = block
280
+ up.attn = attn
281
+ if i_level != 0:
282
+ up.upsample = Upsample(block_in, resamp_with_conv)
283
+ curr_res = curr_res * 2
284
+ self.up.insert(0, up) # prepend to get consistent order
285
+
286
+ # end
287
+ self.norm_out = Normalize(block_in)
288
+ self.conv_out = torch.nn.Conv2d(block_in,
289
+ out_ch,
290
+ kernel_size=3,
291
+ stride=1,
292
+ padding=1)
293
+
294
+
295
+ def forward(self, x, t=None):
296
+ #assert x.shape[2] == x.shape[3] == self.resolution
297
+
298
+ if self.use_timestep:
299
+ # timestep embedding
300
+ assert t is not None
301
+ temb = get_timestep_embedding(t, self.ch)
302
+ temb = self.temb.dense[0](temb)
303
+ temb = nonlinearity(temb)
304
+ temb = self.temb.dense[1](temb)
305
+ else:
306
+ temb = None
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions-1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+
325
+ # upsampling
326
+ for i_level in reversed(range(self.num_resolutions)):
327
+ for i_block in range(self.num_res_blocks+1):
328
+ h = self.up[i_level].block[i_block](
329
+ torch.cat([h, hs.pop()], dim=1), temb)
330
+ if len(self.up[i_level].attn) > 0:
331
+ h = self.up[i_level].attn[i_block](h)
332
+ if i_level != 0:
333
+ h = self.up[i_level].upsample(h)
334
+
335
+ # end
336
+ h = self.norm_out(h)
337
+ h = nonlinearity(h)
338
+ h = self.conv_out(h)
339
+ return h
340
+
341
+
342
+ class Encoder(nn.Module):
343
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
+ resolution, z_channels, double_z=True, **ignore_kwargs):
346
+ super().__init__()
347
+ self.ch = ch
348
+ self.temb_ch = 0
349
+ self.num_resolutions = len(ch_mult)
350
+ self.num_res_blocks = num_res_blocks
351
+ self.resolution = resolution
352
+ self.in_channels = in_channels
353
+
354
+ # downsampling
355
+ self.conv_in = torch.nn.Conv2d(in_channels,
356
+ self.ch,
357
+ kernel_size=3,
358
+ stride=1,
359
+ padding=1)
360
+
361
+ curr_res = resolution
362
+ in_ch_mult = (1,)+tuple(ch_mult)
363
+ self.down = nn.ModuleList()
364
+ for i_level in range(self.num_resolutions):
365
+ block = nn.ModuleList()
366
+ attn = nn.ModuleList()
367
+ block_in = ch*in_ch_mult[i_level]
368
+ block_out = ch*ch_mult[i_level]
369
+ for i_block in range(self.num_res_blocks):
370
+ block.append(ResnetBlock(in_channels=block_in,
371
+ out_channels=block_out,
372
+ temb_channels=self.temb_ch,
373
+ dropout=dropout))
374
+ block_in = block_out
375
+ if curr_res in attn_resolutions:
376
+ attn.append(AttnBlock(block_in))
377
+ down = nn.Module()
378
+ down.block = block
379
+ down.attn = attn
380
+ if i_level != self.num_resolutions-1:
381
+ down.downsample = Downsample(block_in, resamp_with_conv)
382
+ curr_res = curr_res // 2
383
+ self.down.append(down)
384
+
385
+ # middle
386
+ self.mid = nn.Module()
387
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
+ out_channels=block_in,
389
+ temb_channels=self.temb_ch,
390
+ dropout=dropout)
391
+ self.mid.attn_1 = AttnBlock(block_in)
392
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = torch.nn.Conv2d(block_in,
400
+ 2*z_channels if double_z else z_channels,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+
406
+ def forward(self, x):
407
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
+
409
+ # timestep embedding
410
+ temb = None
411
+
412
+ # downsampling
413
+ hs = [self.conv_in(x)]
414
+ for i_level in range(self.num_resolutions):
415
+ for i_block in range(self.num_res_blocks):
416
+ h = self.down[i_level].block[i_block](hs[-1], temb)
417
+ if len(self.down[i_level].attn) > 0:
418
+ h = self.down[i_level].attn[i_block](h)
419
+ hs.append(h)
420
+ if i_level != self.num_resolutions-1:
421
+ hs.append(self.down[i_level].downsample(hs[-1]))
422
+
423
+ # middle
424
+ h = hs[-1]
425
+ h = self.mid.block_1(h, temb)
426
+ h = self.mid.attn_1(h)
427
+ h = self.mid.block_2(h, temb)
428
+
429
+ # end
430
+ h = self.norm_out(h)
431
+ h = nonlinearity(h)
432
+ h = self.conv_out(h)
433
+ return h
434
+
435
+
436
+ class Decoder(nn.Module):
437
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
+ super().__init__()
441
+ self.ch = ch
442
+ self.temb_ch = 0
443
+ self.num_resolutions = len(ch_mult)
444
+ self.num_res_blocks = num_res_blocks
445
+ self.resolution = resolution
446
+ self.in_channels = in_channels
447
+ self.give_pre_end = give_pre_end
448
+
449
+ # compute in_ch_mult, block_in and curr_res at lowest res
450
+ in_ch_mult = (1,)+tuple(ch_mult)
451
+ block_in = ch*ch_mult[self.num_resolutions-1]
452
+ curr_res = resolution // 2**(self.num_resolutions-1)
453
+ self.z_shape = (1,z_channels,curr_res,curr_res)
454
+ print("Working with z of shape {} = {} dimensions.".format(
455
+ self.z_shape, np.prod(self.z_shape)))
456
+
457
+ # z to block_in
458
+ self.conv_in = torch.nn.Conv2d(z_channels,
459
+ block_in,
460
+ kernel_size=3,
461
+ stride=1,
462
+ padding=1)
463
+
464
+ # middle
465
+ self.mid = nn.Module()
466
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout)
470
+ self.mid.attn_1 = AttnBlock(block_in)
471
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout)
475
+
476
+ # upsampling
477
+ self.up = nn.ModuleList()
478
+ for i_level in reversed(range(self.num_resolutions)):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks+1):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(AttnBlock(block_in))
490
+ up = nn.Module()
491
+ up.block = block
492
+ up.attn = attn
493
+ if i_level != 0:
494
+ up.upsample = Upsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res * 2
496
+ self.up.insert(0, up) # prepend to get consistent order
497
+
498
+ # end
499
+ self.norm_out = Normalize(block_in)
500
+ self.conv_out = torch.nn.Conv2d(block_in,
501
+ out_ch,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ def forward(self, z):
507
+ #assert z.shape[1:] == self.z_shape[1:]
508
+ self.last_z_shape = z.shape
509
+
510
+ # timestep embedding
511
+ temb = None
512
+
513
+ # z to block_in
514
+ h = self.conv_in(z)
515
+
516
+ # middle
517
+ h = self.mid.block_1(h, temb)
518
+ h = self.mid.attn_1(h)
519
+ h = self.mid.block_2(h, temb)
520
+
521
+ # upsampling
522
+ for i_level in reversed(range(self.num_resolutions)):
523
+ for i_block in range(self.num_res_blocks+1):
524
+ h = self.up[i_level].block[i_block](h, temb)
525
+ if len(self.up[i_level].attn) > 0:
526
+ h = self.up[i_level].attn[i_block](h)
527
+ if i_level != 0:
528
+ h = self.up[i_level].upsample(h)
529
+
530
+ # end
531
+ if self.give_pre_end:
532
+ return h
533
+
534
+ h = self.norm_out(h)
535
+ h = nonlinearity(h)
536
+ h = self.conv_out(h)
537
+ return h
538
+
539
+
540
+ class VUNet(nn.Module):
541
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
+ in_channels, c_channels,
544
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
+ super().__init__()
546
+ self.ch = ch
547
+ self.temb_ch = self.ch*4
548
+ self.num_resolutions = len(ch_mult)
549
+ self.num_res_blocks = num_res_blocks
550
+ self.resolution = resolution
551
+
552
+ self.use_timestep = use_timestep
553
+ if self.use_timestep:
554
+ # timestep embedding
555
+ self.temb = nn.Module()
556
+ self.temb.dense = nn.ModuleList([
557
+ torch.nn.Linear(self.ch,
558
+ self.temb_ch),
559
+ torch.nn.Linear(self.temb_ch,
560
+ self.temb_ch),
561
+ ])
562
+
563
+ # downsampling
564
+ self.conv_in = torch.nn.Conv2d(c_channels,
565
+ self.ch,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+ curr_res = resolution
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ self.down = nn.ModuleList()
573
+ for i_level in range(self.num_resolutions):
574
+ block = nn.ModuleList()
575
+ attn = nn.ModuleList()
576
+ block_in = ch*in_ch_mult[i_level]
577
+ block_out = ch*ch_mult[i_level]
578
+ for i_block in range(self.num_res_blocks):
579
+ block.append(ResnetBlock(in_channels=block_in,
580
+ out_channels=block_out,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout))
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(AttnBlock(block_in))
586
+ down = nn.Module()
587
+ down.block = block
588
+ down.attn = attn
589
+ if i_level != self.num_resolutions-1:
590
+ down.downsample = Downsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res // 2
592
+ self.down.append(down)
593
+
594
+ self.z_in = torch.nn.Conv2d(z_channels,
595
+ block_in,
596
+ kernel_size=1,
597
+ stride=1,
598
+ padding=0)
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ skip_in = ch*ch_mult[i_level]
618
+ for i_block in range(self.num_res_blocks+1):
619
+ if i_block == self.num_res_blocks:
620
+ skip_in = ch*in_ch_mult[i_level]
621
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
622
+ out_channels=block_out,
623
+ temb_channels=self.temb_ch,
624
+ dropout=dropout))
625
+ block_in = block_out
626
+ if curr_res in attn_resolutions:
627
+ attn.append(AttnBlock(block_in))
628
+ up = nn.Module()
629
+ up.block = block
630
+ up.attn = attn
631
+ if i_level != 0:
632
+ up.upsample = Upsample(block_in, resamp_with_conv)
633
+ curr_res = curr_res * 2
634
+ self.up.insert(0, up) # prepend to get consistent order
635
+
636
+ # end
637
+ self.norm_out = Normalize(block_in)
638
+ self.conv_out = torch.nn.Conv2d(block_in,
639
+ out_ch,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1)
643
+
644
+
645
+ def forward(self, x, z, t=None):
646
+ #assert x.shape[2] == x.shape[3] == self.resolution
647
+
648
+ if self.use_timestep:
649
+ # timestep embedding
650
+ assert t is not None
651
+ temb = get_timestep_embedding(t, self.ch)
652
+ temb = self.temb.dense[0](temb)
653
+ temb = nonlinearity(temb)
654
+ temb = self.temb.dense[1](temb)
655
+ else:
656
+ temb = None
657
+
658
+ # downsampling
659
+ hs = [self.conv_in(x)]
660
+ for i_level in range(self.num_resolutions):
661
+ for i_block in range(self.num_res_blocks):
662
+ h = self.down[i_level].block[i_block](hs[-1], temb)
663
+ if len(self.down[i_level].attn) > 0:
664
+ h = self.down[i_level].attn[i_block](h)
665
+ hs.append(h)
666
+ if i_level != self.num_resolutions-1:
667
+ hs.append(self.down[i_level].downsample(hs[-1]))
668
+
669
+ # middle
670
+ h = hs[-1]
671
+ z = self.z_in(z)
672
+ h = torch.cat((h,z),dim=1)
673
+ h = self.mid.block_1(h, temb)
674
+ h = self.mid.attn_1(h)
675
+ h = self.mid.block_2(h, temb)
676
+
677
+ # upsampling
678
+ for i_level in reversed(range(self.num_resolutions)):
679
+ for i_block in range(self.num_res_blocks+1):
680
+ h = self.up[i_level].block[i_block](
681
+ torch.cat([h, hs.pop()], dim=1), temb)
682
+ if len(self.up[i_level].attn) > 0:
683
+ h = self.up[i_level].attn[i_block](h)
684
+ if i_level != 0:
685
+ h = self.up[i_level].upsample(h)
686
+
687
+ # end
688
+ h = self.norm_out(h)
689
+ h = nonlinearity(h)
690
+ h = self.conv_out(h)
691
+ return h
692
+
693
+
694
+ class SimpleDecoder(nn.Module):
695
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
696
+ super().__init__()
697
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
+ ResnetBlock(in_channels=in_channels,
699
+ out_channels=2 * in_channels,
700
+ temb_channels=0, dropout=0.0),
701
+ ResnetBlock(in_channels=2 * in_channels,
702
+ out_channels=4 * in_channels,
703
+ temb_channels=0, dropout=0.0),
704
+ ResnetBlock(in_channels=4 * in_channels,
705
+ out_channels=2 * in_channels,
706
+ temb_channels=0, dropout=0.0),
707
+ nn.Conv2d(2*in_channels, in_channels, 1),
708
+ Upsample(in_channels, with_conv=True)])
709
+ # end
710
+ self.norm_out = Normalize(in_channels)
711
+ self.conv_out = torch.nn.Conv2d(in_channels,
712
+ out_channels,
713
+ kernel_size=3,
714
+ stride=1,
715
+ padding=1)
716
+
717
+ def forward(self, x):
718
+ for i, layer in enumerate(self.model):
719
+ if i in [1,2,3]:
720
+ x = layer(x, None)
721
+ else:
722
+ x = layer(x)
723
+
724
+ h = self.norm_out(x)
725
+ h = nonlinearity(h)
726
+ x = self.conv_out(h)
727
+ return x
728
+
729
+
730
+ class UpsampleDecoder(nn.Module):
731
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
+ ch_mult=(2,2), dropout=0.0):
733
+ super().__init__()
734
+ # upsampling
735
+ self.temb_ch = 0
736
+ self.num_resolutions = len(ch_mult)
737
+ self.num_res_blocks = num_res_blocks
738
+ block_in = in_channels
739
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
+ self.res_blocks = nn.ModuleList()
741
+ self.upsample_blocks = nn.ModuleList()
742
+ for i_level in range(self.num_resolutions):
743
+ res_block = []
744
+ block_out = ch * ch_mult[i_level]
745
+ for i_block in range(self.num_res_blocks + 1):
746
+ res_block.append(ResnetBlock(in_channels=block_in,
747
+ out_channels=block_out,
748
+ temb_channels=self.temb_ch,
749
+ dropout=dropout))
750
+ block_in = block_out
751
+ self.res_blocks.append(nn.ModuleList(res_block))
752
+ if i_level != self.num_resolutions - 1:
753
+ self.upsample_blocks.append(Upsample(block_in, True))
754
+ curr_res = curr_res * 2
755
+
756
+ # end
757
+ self.norm_out = Normalize(block_in)
758
+ self.conv_out = torch.nn.Conv2d(block_in,
759
+ out_channels,
760
+ kernel_size=3,
761
+ stride=1,
762
+ padding=1)
763
+
764
+ def forward(self, x):
765
+ # upsampling
766
+ h = x
767
+ for k, i_level in enumerate(range(self.num_resolutions)):
768
+ for i_block in range(self.num_res_blocks + 1):
769
+ h = self.res_blocks[i_level][i_block](h, None)
770
+ if i_level != self.num_resolutions - 1:
771
+ h = self.upsample_blocks[k](h)
772
+ h = self.norm_out(h)
773
+ h = nonlinearity(h)
774
+ h = self.conv_out(h)
775
+ return h
776
+
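
A small shape check for the Encoder/Decoder pair above (editorial sketch with illustrative hyperparameters): the spatial resolution shrinks by a factor of 2**(len(ch_mult)-1) on the way down and is restored on the way up.

    import torch
    from taming.modules.diffusionmodules.model import Encoder, Decoder

    cfg = dict(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
               attn_resolutions=[16], in_channels=3, resolution=64,
               z_channels=32, dropout=0.0)
    enc = Encoder(double_z=False, **cfg)
    dec = Decoder(**cfg)

    x = torch.randn(1, 3, 64, 64)
    z = enc(x)      # (1, 32, 16, 16): 64 / 2**(3-1) = 16
    xrec = dec(z)   # (1, 3, 64, 64)
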
taming/modules/discriminator/__pycache__/model.cpython-38.pyc ADDED
Binary file (2.34 kB). View file
 
taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,67 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from taming.modules.util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find('Conv') != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ norm_layer -- normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
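
A quick sanity check of the patch discriminator above (illustrative only): with the default n_layers=3 it emits a grid of per-patch logits rather than a single score, e.g. 30x30 for a 256x256 input.

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
    disc.apply(weights_init)        # normal(0, 0.02) init as in Pix2Pix

    x = torch.randn(2, 3, 256, 256)
    logits = disc(x)                # (2, 1, 30, 30): one logit per overlapping patch
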
taming/modules/losses/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from taming.modules.losses.vqperceptual import DummyLoss
2
+
taming/modules/losses/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (231 Bytes). View file
 
taming/modules/losses/__pycache__/lpips.cpython-38.pyc ADDED
Binary file (5.3 kB). View file
 
taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc ADDED
Binary file (6.94 kB). View file
 
taming/modules/losses/lpips.py ADDED
@@ -0,0 +1,123 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+ from taming.util import get_ckpt_path
9
+
10
+
11
+ class LPIPS(nn.Module):
12
+ # Learned perceptual metric
13
+ def __init__(self, use_dropout=True):
14
+ super().__init__()
15
+ self.scaling_layer = ScalingLayer()
16
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
17
+ self.net = vgg16(pretrained=True, requires_grad=False)
18
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23
+ self.load_from_pretrained()
24
+ for param in self.parameters():
25
+ param.requires_grad = False
26
+
27
+ def load_from_pretrained(self, name="vgg_lpips"):
28
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
29
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
31
+
32
+ @classmethod
33
+ def from_pretrained(cls, name="vgg_lpips"):
34
+ if name != "vgg_lpips":
35
+ raise NotImplementedError
36
+ model = cls()
37
+ ckpt = get_ckpt_path(name)
38
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39
+ return model
40
+
41
+ def forward(self, input, target):
42
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
44
+ feats0, feats1, diffs = {}, {}, {}
45
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46
+ for kk in range(len(self.chns)):
47
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49
+
50
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51
+ val = res[0]
52
+ for l in range(1, len(self.chns)):
53
+ val += res[l]
54
+ return val
55
+
56
+
57
+ class ScalingLayer(nn.Module):
58
+ def __init__(self):
59
+ super(ScalingLayer, self).__init__()
60
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62
+
63
+ def forward(self, inp):
64
+ return (inp - self.shift) / self.scale
65
+
66
+
67
+ class NetLinLayer(nn.Module):
68
+ """ A single linear layer which does a 1x1 conv """
69
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
70
+ super(NetLinLayer, self).__init__()
71
+ layers = [nn.Dropout(), ] if (use_dropout) else []
72
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73
+ self.model = nn.Sequential(*layers)
74
+
75
+
76
+ class vgg16(torch.nn.Module):
77
+ def __init__(self, requires_grad=False, pretrained=True):
78
+ super(vgg16, self).__init__()
79
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80
+ self.slice1 = torch.nn.Sequential()
81
+ self.slice2 = torch.nn.Sequential()
82
+ self.slice3 = torch.nn.Sequential()
83
+ self.slice4 = torch.nn.Sequential()
84
+ self.slice5 = torch.nn.Sequential()
85
+ self.N_slices = 5
86
+ for x in range(4):
87
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
88
+ for x in range(4, 9):
89
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
90
+ for x in range(9, 16):
91
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
92
+ for x in range(16, 23):
93
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
94
+ for x in range(23, 30):
95
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
96
+ if not requires_grad:
97
+ for param in self.parameters():
98
+ param.requires_grad = False
99
+
100
+ def forward(self, X):
101
+ h = self.slice1(X)
102
+ h_relu1_2 = h
103
+ h = self.slice2(h)
104
+ h_relu2_2 = h
105
+ h = self.slice3(h)
106
+ h_relu3_3 = h
107
+ h = self.slice4(h)
108
+ h_relu4_3 = h
109
+ h = self.slice5(h)
110
+ h_relu5_3 = h
111
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113
+ return out
114
+
115
+
116
+ def normalize_tensor(x,eps=1e-10):
117
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118
+ return x/(norm_factor+eps)
119
+
120
+
121
+ def spatial_average(x, keepdim=True):
122
+ return x.mean([2,3],keepdim=keepdim)
123
+
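A brief sketch (illustrative, not from the diff) of how the LPIPS module above is typically called: inputs are assumed to be RGB tensors scaled to [-1, 1], and constructing the module fetches the pretrained vgg_lpips weights via get_ckpt_path.

import torch
from taming.modules.losses.lpips import LPIPS

lpips = LPIPS().eval()                     # loads the pretrained linear layers on top of VGG16
x = torch.rand(4, 3, 256, 256) * 2 - 1     # images assumed to be in [-1, 1]
y = torch.rand(4, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)                        # shape (4, 1, 1, 1): one perceptual distance per image
print(d.squeeze())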
taming/modules/losses/segmentation.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class BCELoss(nn.Module):
6
+ def forward(self, prediction, target):
7
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
8
+ return loss, {}
9
+
10
+
11
+ class BCELossWithQuant(nn.Module):
12
+ def __init__(self, codebook_weight=1.):
13
+ super().__init__()
14
+ self.codebook_weight = codebook_weight
15
+
16
+ def forward(self, qloss, target, prediction, split):
17
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18
+ loss = bce_loss + self.codebook_weight*qloss
19
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
21
+ "{}/quant_loss".format(split): qloss.detach().mean()
22
+ }
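A short sketch (illustrative, not from the diff) showing the call convention of BCELossWithQuant: the quantization loss from the VQ bottleneck is added to the per-pixel binary cross-entropy, and a log dict keyed by split is returned.

import torch
from taming.modules.losses.segmentation import BCELossWithQuant

loss_fn = BCELossWithQuant(codebook_weight=1.0)
pred = torch.randn(2, 151, 64, 64)                       # segmentation logits
target = torch.randint(0, 2, (2, 151, 64, 64)).float()   # multi-hot targets
qloss = torch.tensor(0.1)                                # codebook/commitment loss
loss, log = loss_fn(qloss, target, pred, split="train")
print(loss, log["train/bce_loss"])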
taming/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from taming.modules.losses.lpips import LPIPS
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+
8
+
9
+ class DummyLoss(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+
14
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
15
+ if global_step < threshold:
16
+ weight = value
17
+ return weight
18
+
19
+
20
+ def hinge_d_loss(logits_real, logits_fake):
21
+ loss_real = torch.mean(F.relu(1. - logits_real))
22
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
23
+ d_loss = 0.5 * (loss_real + loss_fake)
24
+ return d_loss
25
+
26
+
27
+ def vanilla_d_loss(logits_real, logits_fake):
28
+ d_loss = 0.5 * (
29
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
30
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
31
+ return d_loss
32
+
33
+
34
+ class VQLPIPSWithDiscriminator(nn.Module):
35
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
36
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
37
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
38
+ disc_ndf=64, disc_loss="hinge"):
39
+ super().__init__()
40
+ assert disc_loss in ["hinge", "vanilla"]
41
+ self.codebook_weight = codebook_weight
42
+ self.pixel_weight = pixelloss_weight
43
+ self.perceptual_loss = LPIPS().eval()
44
+ self.perceptual_weight = perceptual_weight
45
+
46
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
47
+ n_layers=disc_num_layers,
48
+ use_actnorm=use_actnorm,
49
+ ndf=disc_ndf
50
+ ).apply(weights_init)
51
+ self.discriminator_iter_start = disc_start
52
+ if disc_loss == "hinge":
53
+ self.disc_loss = hinge_d_loss
54
+ elif disc_loss == "vanilla":
55
+ self.disc_loss = vanilla_d_loss
56
+ else:
57
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
58
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
59
+ self.disc_factor = disc_factor
60
+ self.discriminator_weight = disc_weight
61
+ self.disc_conditional = disc_conditional
62
+
63
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
64
+ if last_layer is not None:
65
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
66
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
67
+ else:
68
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
69
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
70
+
71
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
72
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
73
+ d_weight = d_weight * self.discriminator_weight
74
+ return d_weight
75
+
76
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
77
+ global_step, last_layer=None, cond=None, split="train"):
78
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
79
+ if self.perceptual_weight > 0:
80
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
81
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
82
+ else:
83
+ p_loss = torch.tensor([0.0])
84
+
85
+ nll_loss = rec_loss
86
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
87
+ nll_loss = torch.mean(nll_loss)
88
+
89
+ # now the GAN part
90
+ if optimizer_idx == 0:
91
+ # generator update
92
+ if cond is None:
93
+ assert not self.disc_conditional
94
+ logits_fake = self.discriminator(reconstructions.contiguous())
95
+ else:
96
+ assert self.disc_conditional
97
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98
+ g_loss = -torch.mean(logits_fake)
99
+
100
+ try:
101
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
102
+ except RuntimeError:
103
+ assert not self.training
104
+ d_weight = torch.tensor(0.0)
105
+
106
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
107
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
108
+
109
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
110
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
111
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
112
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
113
+ "{}/p_loss".format(split): p_loss.detach().mean(),
114
+ "{}/d_weight".format(split): d_weight.detach(),
115
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
116
+ "{}/g_loss".format(split): g_loss.detach().mean(),
117
+ }
118
+ return loss, log
119
+
120
+ if optimizer_idx == 1:
121
+ # second pass for discriminator update
122
+ if cond is None:
123
+ logits_real = self.discriminator(inputs.contiguous().detach())
124
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
125
+ else:
126
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
127
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
128
+
129
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
130
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
131
+
132
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
133
+ "{}/logits_real".format(split): logits_real.detach().mean(),
134
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
135
+ }
136
+ return d_loss, log
137
+
138
+ class LPIPSWithDiscriminator(nn.Module):
139
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
140
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
141
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
142
+ disc_loss="hinge"):
143
+
144
+ super().__init__()
145
+ assert disc_loss in ["hinge", "vanilla"]
146
+ self.kl_weight = kl_weight
147
+ self.pixel_weight = pixelloss_weight
148
+ self.perceptual_loss = LPIPS().eval()
149
+ self.perceptual_weight = perceptual_weight
150
+ # output log variance
151
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
152
+
153
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
154
+ n_layers=disc_num_layers,
155
+ use_actnorm=use_actnorm
156
+ ).apply(weights_init)
157
+ self.discriminator_iter_start = disc_start
158
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
159
+ self.disc_factor = disc_factor
160
+ self.discriminator_weight = disc_weight
161
+ self.disc_conditional = disc_conditional
162
+
163
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
164
+ if last_layer is not None:
165
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
166
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
167
+ else:
168
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
169
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
170
+
171
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
172
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
173
+ d_weight = d_weight * self.discriminator_weight
174
+ return d_weight
175
+
176
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
177
+ global_step, last_layer=None, cond=None, split="train",
178
+ weights=None):
179
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
180
+ if self.perceptual_weight > 0:
181
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
182
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
183
+
184
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
185
+ weighted_nll_loss = nll_loss
186
+ if weights is not None:
187
+ weighted_nll_loss = weights*nll_loss
188
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
189
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
190
+ kl_loss = posteriors.kl()
191
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
192
+
193
+ # now the GAN part
194
+ if optimizer_idx == 0:
195
+ # generator update
196
+ if cond is None:
197
+ assert not self.disc_conditional
198
+ logits_fake = self.discriminator(reconstructions.contiguous())
199
+ else:
200
+ assert self.disc_conditional
201
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
202
+ g_loss = -torch.mean(logits_fake)
203
+
204
+ if self.disc_factor > 0.0:
205
+ try:
206
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
207
+ except RuntimeError:
208
+ assert not self.training
209
+ d_weight = torch.tensor(0.0)
210
+ else:
211
+ d_weight = torch.tensor(0.0)
212
+
213
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
214
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
215
+
216
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
217
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
218
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
219
+ "{}/d_weight".format(split): d_weight.detach(),
220
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
221
+ "{}/g_loss".format(split): g_loss.detach().mean(),
222
+ }
223
+ return loss, log
224
+
225
+ if optimizer_idx == 1:
226
+ # second pass for discriminator update
227
+ if cond is None:
228
+ logits_real = self.discriminator(inputs.contiguous().detach())
229
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
230
+ else:
231
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
232
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
233
+
234
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
235
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
236
+
237
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
238
+ "{}/logits_real".format(split): logits_real.detach().mean(),
239
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
240
+ }
241
+ return d_loss, log
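For reference, a small runnable sketch (illustrative, not from the diff) of the two helpers defined at the top of this file: adopt_weight keeps the adversarial term switched off until disc_start is reached, and hinge_d_loss is the standard hinge objective for the discriminator.

import torch
from taming.modules.losses.vqperceptual import adopt_weight, hinge_d_loss

print(adopt_weight(1.0, global_step=500, threshold=10000))    # 0.0 -> GAN term still disabled
print(adopt_weight(1.0, global_step=20000, threshold=10000))  # 1.0 -> GAN term active

logits_real = torch.tensor([2.0, 0.5])
logits_fake = torch.tensor([-1.5, 0.2])
# 0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))
print(hinge_d_loss(logits_real, logits_fake))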
taming/modules/misc/coord.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+ class CoordStage(object):
4
+ def __init__(self, n_embed, down_factor):
5
+ self.n_embed = n_embed
6
+ self.down_factor = down_factor
7
+
8
+ def eval(self):
9
+ return self
10
+
11
+ def encode(self, c):
12
+ """fake vqmodel interface"""
13
+ assert 0.0 <= c.min() and c.max() <= 1.0
14
+ b,ch,h,w = c.shape
15
+ assert ch == 1
16
+
17
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18
+ mode="area")
19
+ c = c.clamp(0.0, 1.0)
20
+ c = self.n_embed*c
21
+ c_quant = c.round()
22
+ c_ind = c_quant.to(dtype=torch.long)
23
+
24
+ info = None, None, c_ind
25
+ return c_quant, None, info
26
+
27
+ def decode(self, c):
28
+ c = c/self.n_embed
29
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30
+ mode="nearest")
31
+ return c
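A tiny sketch (illustrative, not from the diff) of the fake VQ interface above: CoordStage downsamples a single-channel coordinate map with values in [0, 1] and buckets it into n_embed integer bins, mirroring the (quant, loss, info) return convention of the VQ models.

import torch
from taming.modules.misc.coord import CoordStage

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                 # coordinate map with values in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)    # both outputs have spatial size 16x16
c_back = stage.decode(c_quant)                 # nearest-neighbour upsampling back to 256x256
print(c_quant.shape, c_ind.shape, c_back.shape)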
taming/modules/transformer/mingpt.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ taken from: https://github.com/karpathy/minGPT/
3
+ GPT model:
4
+ - the initial stem consists of a combination of token encoding and a positional encoding
5
+ - the meat of it is a uniform sequence of Transformer blocks
6
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
7
+ - all blocks feed into a central residual pathway similar to resnets
8
+ - the final decoder is a linear projection into a vanilla Softmax classifier
9
+ """
10
+
11
+ import math
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ from transformers import top_k_top_p_filtering
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GPTConfig:
23
+ """ base GPT config, params common to all GPT versions """
24
+ embd_pdrop = 0.1
25
+ resid_pdrop = 0.1
26
+ attn_pdrop = 0.1
27
+
28
+ def __init__(self, vocab_size, block_size, **kwargs):
29
+ self.vocab_size = vocab_size
30
+ self.block_size = block_size
31
+ for k,v in kwargs.items():
32
+ setattr(self, k, v)
33
+
34
+
35
+ class GPT1Config(GPTConfig):
36
+ """ GPT-1 like network roughly 125M params """
37
+ n_layer = 12
38
+ n_head = 12
39
+ n_embd = 768
40
+
41
+
42
+ class CausalSelfAttention(nn.Module):
43
+ """
44
+ A vanilla multi-head masked self-attention layer with a projection at the end.
45
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
46
+ explicit implementation here to show that there is nothing too scary here.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ assert config.n_embd % config.n_head == 0
52
+ # key, query, value projections for all heads
53
+ self.key = nn.Linear(config.n_embd, config.n_embd)
54
+ self.query = nn.Linear(config.n_embd, config.n_embd)
55
+ self.value = nn.Linear(config.n_embd, config.n_embd)
56
+ # regularization
57
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
58
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
59
+ # output projection
60
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
61
+ # causal mask to ensure that attention is only applied to the left in the input sequence
62
+ mask = torch.tril(torch.ones(config.block_size,
63
+ config.block_size))
64
+ if hasattr(config, "n_unmasked"):
65
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
66
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
67
+ self.n_head = config.n_head
68
+
69
+ def forward(self, x, layer_past=None):
70
+ B, T, C = x.size()
71
+
72
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
73
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
+
77
+ present = torch.stack((k, v))
78
+ if layer_past is not None:
79
+ past_key, past_value = layer_past
80
+ k = torch.cat((past_key, k), dim=-2)
81
+ v = torch.cat((past_value, v), dim=-2)
82
+
83
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
84
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
85
+ if layer_past is None:
86
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
87
+
88
+ att = F.softmax(att, dim=-1)
89
+ att = self.attn_drop(att)
90
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
91
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
92
+
93
+ # output projection
94
+ y = self.resid_drop(self.proj(y))
95
+ return y, present # TODO: check that this does not break anything
96
+
97
+
98
+ class Block(nn.Module):
99
+ """ an unassuming Transformer block """
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.ln1 = nn.LayerNorm(config.n_embd)
103
+ self.ln2 = nn.LayerNorm(config.n_embd)
104
+ self.attn = CausalSelfAttention(config)
105
+ self.mlp = nn.Sequential(
106
+ nn.Linear(config.n_embd, 4 * config.n_embd),
107
+ nn.GELU(), # nice
108
+ nn.Linear(4 * config.n_embd, config.n_embd),
109
+ nn.Dropout(config.resid_pdrop),
110
+ )
111
+
112
+ def forward(self, x, layer_past=None, return_present=False):
113
+ # TODO: check that training still works
114
+ if return_present: assert not self.training
115
+ # layer past: tuple of length two with B, nh, T, hs
116
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
117
+
118
+ x = x + attn
119
+ x = x + self.mlp(self.ln2(x))
120
+ if layer_past is not None or return_present:
121
+ return x, present
122
+ return x
123
+
124
+
125
+ class GPT(nn.Module):
126
+ """ the full GPT language model, with a context size of block_size """
127
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
128
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
129
+ super().__init__()
130
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
131
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
132
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
133
+ n_unmasked=n_unmasked)
134
+ # input embedding stem
135
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
136
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
137
+ self.drop = nn.Dropout(config.embd_pdrop)
138
+ # transformer
139
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
140
+ # decoder head
141
+ self.ln_f = nn.LayerNorm(config.n_embd)
142
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
143
+ self.block_size = config.block_size
144
+ self.apply(self._init_weights)
145
+ self.config = config
146
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
147
+
148
+ def get_block_size(self):
149
+ return self.block_size
150
+
151
+ def _init_weights(self, module):
152
+ if isinstance(module, (nn.Linear, nn.Embedding)):
153
+ module.weight.data.normal_(mean=0.0, std=0.02)
154
+ if isinstance(module, nn.Linear) and module.bias is not None:
155
+ module.bias.data.zero_()
156
+ elif isinstance(module, nn.LayerNorm):
157
+ module.bias.data.zero_()
158
+ module.weight.data.fill_(1.0)
159
+
160
+ def forward(self, idx, embeddings=None, targets=None):
161
+ # forward the GPT model
162
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
163
+
164
+ if embeddings is not None: # prepend explicit embeddings
165
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
166
+
167
+ t = token_embeddings.shape[1]
168
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
169
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
170
+ x = self.drop(token_embeddings + position_embeddings)
171
+ x = self.blocks(x)
172
+ x = self.ln_f(x)
173
+ logits = self.head(x)
174
+
175
+ # if we are given some desired targets also calculate the loss
176
+ loss = None
177
+ if targets is not None:
178
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
179
+
180
+ return logits, loss
181
+
182
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
183
+ # inference only
184
+ assert not self.training
185
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
186
+ if embeddings is not None: # prepend explicit embeddings
187
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
188
+
189
+ if past is not None:
190
+ assert past_length is not None
191
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
192
+ past_shape = list(past.shape)
193
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
194
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
195
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
196
+ else:
197
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
198
+
199
+ x = self.drop(token_embeddings + position_embeddings)
200
+ presents = [] # accumulate over layers
201
+ for i, block in enumerate(self.blocks):
202
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
203
+ presents.append(present)
204
+
205
+ x = self.ln_f(x)
206
+ logits = self.head(x)
207
+ # if we are given some desired targets also calculate the loss
208
+ loss = None
209
+ if targets is not None:
210
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
211
+
212
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
213
+
214
+
215
+ class DummyGPT(nn.Module):
216
+ # for debugging
217
+ def __init__(self, add_value=1):
218
+ super().__init__()
219
+ self.add_value = add_value
220
+
221
+ def forward(self, idx):
222
+ return idx + self.add_value, None
223
+
224
+
225
+ class CodeGPT(nn.Module):
226
+ """Takes in semi-embeddings"""
227
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
228
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
229
+ super().__init__()
230
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
231
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
232
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
233
+ n_unmasked=n_unmasked)
234
+ # input embedding stem
235
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
236
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
237
+ self.drop = nn.Dropout(config.embd_pdrop)
238
+ # transformer
239
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
240
+ # decoder head
241
+ self.ln_f = nn.LayerNorm(config.n_embd)
242
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
243
+ self.block_size = config.block_size
244
+ self.apply(self._init_weights)
245
+ self.config = config
246
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
247
+
248
+ def get_block_size(self):
249
+ return self.block_size
250
+
251
+ def _init_weights(self, module):
252
+ if isinstance(module, (nn.Linear, nn.Embedding)):
253
+ module.weight.data.normal_(mean=0.0, std=0.02)
254
+ if isinstance(module, nn.Linear) and module.bias is not None:
255
+ module.bias.data.zero_()
256
+ elif isinstance(module, nn.LayerNorm):
257
+ module.bias.data.zero_()
258
+ module.weight.data.fill_(1.0)
259
+
260
+ def forward(self, idx, embeddings=None, targets=None):
261
+ # forward the GPT model
262
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
263
+
264
+ if embeddings is not None: # prepend explicit embeddings
265
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
266
+
267
+ t = token_embeddings.shape[1]
268
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
269
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
270
+ x = self.drop(token_embeddings + position_embeddings)
271
+ x = self.blocks(x)
272
+ x = self.ln_f(x)
273
+ logits = self.head(x)
274
+
275
+ # if we are given some desired targets also calculate the loss
276
+ loss = None
277
+ if targets is not None:
278
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
279
+
280
+ return logits, loss
281
+
282
+
283
+
284
+ #### sampling utils
285
+
286
+ def top_k_logits(logits, k):
287
+ v, ix = torch.topk(logits, k)
288
+ out = logits.clone()
289
+ out[out < v[:, [-1]]] = -float('Inf')
290
+ return out
291
+
292
+ @torch.no_grad()
293
+ def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
294
+ """
295
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
296
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
297
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
298
+ of block_size, unlike an RNN that has an infinite context window.
299
+ """
300
+ block_size = model.get_block_size()
301
+ model.eval()
302
+ for k in range(steps):
303
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
304
+ logits, _ = model(x_cond)
305
+ # pluck the logits at the final step and scale by temperature
306
+ logits = logits[:, -1, :] / temperature
307
+ # optionally crop probabilities to only the top k options
308
+ if top_k is not None:
309
+ logits = top_k_logits(logits, top_k)
310
+ # apply softmax to convert to probabilities
311
+ probs = F.softmax(logits, dim=-1)
312
+ # sample from the distribution or take the most likely
313
+ if sample:
314
+ ix = torch.multinomial(probs, num_samples=1)
315
+ else:
316
+ _, ix = torch.topk(probs, k=1, dim=-1)
317
+ # append to the sequence and continue
318
+ x = torch.cat((x, ix), dim=1)
319
+
320
+ return x
321
+
322
+
323
+ @torch.no_grad()
324
+ def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
325
+ top_k=None, top_p=None, callback=None):
326
+ # x is conditioning
327
+ sample = x
328
+ cond_len = x.shape[1]
329
+ past = None
330
+ for n in range(steps):
331
+ if callback is not None:
332
+ callback(n)
333
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
334
+ if past is None:
335
+ past = [present]
336
+ else:
337
+ past.append(present)
338
+ logits = logits[:, -1, :] / temperature
339
+ if top_k is not None:
340
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
341
+
342
+ probs = F.softmax(logits, dim=-1)
343
+ if not sample_logits:
344
+ _, x = torch.topk(probs, k=1, dim=-1)
345
+ else:
346
+ x = torch.multinomial(probs, num_samples=1)
347
+ # append to the sequence and continue
348
+ sample = torch.cat((sample, x), dim=1)
349
+ del past
350
+ sample = sample[:, cond_len:] # cut conditioning off
351
+ return sample
352
+
353
+
354
+ #### clustering utils
355
+
356
+ class KMeans(nn.Module):
357
+ def __init__(self, ncluster=512, nc=3, niter=10):
358
+ super().__init__()
359
+ self.ncluster = ncluster
360
+ self.nc = nc
361
+ self.niter = niter
362
+ self.shape = (3,32,32)
363
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
364
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
365
+
366
+ def is_initialized(self):
367
+ return self.initialized.item() == 1
368
+
369
+ @torch.no_grad()
370
+ def initialize(self, x):
371
+ N, D = x.shape
372
+ assert D == self.nc, D
373
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
374
+ for i in range(self.niter):
375
+ # assign all pixels to the closest codebook element
376
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
377
+ # move each codebook element to be the mean of the pixels that assigned to it
378
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
379
+ # re-assign any poorly positioned codebook elements
380
+ nanix = torch.any(torch.isnan(c), dim=1)
381
+ ndead = nanix.sum().item()
382
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
383
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
384
+
385
+ self.C.copy_(c)
386
+ self.initialized.fill_(1)
387
+
388
+
389
+ def forward(self, x, reverse=False, shape=None):
390
+ if not reverse:
391
+ # flatten
392
+ bs,c,h,w = x.shape
393
+ assert c == self.nc
394
+ x = x.reshape(bs,c,h*w,1)
395
+ C = self.C.permute(1,0)
396
+ C = C.reshape(1,c,1,self.ncluster)
397
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
398
+ return a
399
+ else:
400
+ # flatten
401
+ bs, HW = x.shape
402
+ """
403
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
404
+ c = c[bs*[0],:,:,:]
405
+ c = c[:,:,HW*[0],:]
406
+ x = x.reshape(bs, 1, HW, 1)
407
+ x = x[:,3*[0],:,:]
408
+ x = torch.gather(c, dim=3, index=x)
409
+ """
410
+ x = self.C[x]
411
+ x = x.permute(0,2,1)
412
+ shape = shape if shape is not None else self.shape
413
+ x = x.reshape(bs, *shape)
414
+
415
+ return x
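A minimal sampling sketch (illustrative, not from the diff), using a tiny randomly initialised GPT; it assumes the transformers dependency imported at the top of this file is available.

import torch
from taming.modules.transformer.mingpt import GPT, sample

model = GPT(vocab_size=128, block_size=64, n_layer=2, n_head=4, n_embd=64)
cond = torch.randint(0, 128, (1, 8))        # conditioning token indices of shape (B, T)
out = sample(model, cond, steps=16, temperature=1.0, sample=True, top_k=10)
print(out.shape)                            # (1, 24): the conditioning plus 16 sampled tokens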
taming/modules/transformer/permuter.py ADDED
@@ -0,0 +1,248 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class AbstractPermuter(nn.Module):
7
+ def __init__(self, *args, **kwargs):
8
+ super().__init__()
9
+ def forward(self, x, reverse=False):
10
+ raise NotImplementedError
11
+
12
+
13
+ class Identity(AbstractPermuter):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, x, reverse=False):
18
+ return x
19
+
20
+
21
+ class Subsample(AbstractPermuter):
22
+ def __init__(self, H, W):
23
+ super().__init__()
24
+ C = 1
25
+ indices = np.arange(H*W).reshape(C,H,W)
26
+ while min(H, W) > 1:
27
+ indices = indices.reshape(C,H//2,2,W//2,2)
28
+ indices = indices.transpose(0,2,4,1,3)
29
+ indices = indices.reshape(C*4,H//2, W//2)
30
+ H = H//2
31
+ W = W//2
32
+ C = C*4
33
+ assert H == W == 1
34
+ idx = torch.tensor(indices.ravel())
35
+ self.register_buffer('forward_shuffle_idx',
36
+ nn.Parameter(idx, requires_grad=False))
37
+ self.register_buffer('backward_shuffle_idx',
38
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
39
+
40
+ def forward(self, x, reverse=False):
41
+ if not reverse:
42
+ return x[:, self.forward_shuffle_idx]
43
+ else:
44
+ return x[:, self.backward_shuffle_idx]
45
+
46
+
47
+ def mortonify(i, j):
48
+ """(i,j) index to linear morton code"""
49
+ i = np.uint64(i)
50
+ j = np.uint64(j)
51
+
52
+ z = np.uint(0)
53
+
54
+ for pos in range(32):
55
+ z = (z |
56
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
57
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
58
+ )
59
+ return z
60
+
61
+
62
+ class ZCurve(AbstractPermuter):
63
+ def __init__(self, H, W):
64
+ super().__init__()
65
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
66
+ idx = np.argsort(reverseidx)
67
+ idx = torch.tensor(idx)
68
+ reverseidx = torch.tensor(reverseidx)
69
+ self.register_buffer('forward_shuffle_idx',
70
+ idx)
71
+ self.register_buffer('backward_shuffle_idx',
72
+ reverseidx)
73
+
74
+ def forward(self, x, reverse=False):
75
+ if not reverse:
76
+ return x[:, self.forward_shuffle_idx]
77
+ else:
78
+ return x[:, self.backward_shuffle_idx]
79
+
80
+
81
+ class SpiralOut(AbstractPermuter):
82
+ def __init__(self, H, W):
83
+ super().__init__()
84
+ assert H == W
85
+ size = W
86
+ indices = np.arange(size*size).reshape(size,size)
87
+
88
+ i0 = size//2
89
+ j0 = size//2-1
90
+
91
+ i = i0
92
+ j = j0
93
+
94
+ idx = [indices[i0, j0]]
95
+ step_mult = 0
96
+ for c in range(1, size//2+1):
97
+ step_mult += 1
98
+ # steps left
99
+ for k in range(step_mult):
100
+ i = i - 1
101
+ j = j
102
+ idx.append(indices[i, j])
103
+
104
+ # step down
105
+ for k in range(step_mult):
106
+ i = i
107
+ j = j + 1
108
+ idx.append(indices[i, j])
109
+
110
+ step_mult += 1
111
+ if c < size//2:
112
+ # step right
113
+ for k in range(step_mult):
114
+ i = i + 1
115
+ j = j
116
+ idx.append(indices[i, j])
117
+
118
+ # step up
119
+ for k in range(step_mult):
120
+ i = i
121
+ j = j - 1
122
+ idx.append(indices[i, j])
123
+ else:
124
+ # end reached
125
+ for k in range(step_mult-1):
126
+ i = i + 1
127
+ idx.append(indices[i, j])
128
+
129
+ assert len(idx) == size*size
130
+ idx = torch.tensor(idx)
131
+ self.register_buffer('forward_shuffle_idx', idx)
132
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
133
+
134
+ def forward(self, x, reverse=False):
135
+ if not reverse:
136
+ return x[:, self.forward_shuffle_idx]
137
+ else:
138
+ return x[:, self.backward_shuffle_idx]
139
+
140
+
141
+ class SpiralIn(AbstractPermuter):
142
+ def __init__(self, H, W):
143
+ super().__init__()
144
+ assert H == W
145
+ size = W
146
+ indices = np.arange(size*size).reshape(size,size)
147
+
148
+ i0 = size//2
149
+ j0 = size//2-1
150
+
151
+ i = i0
152
+ j = j0
153
+
154
+ idx = [indices[i0, j0]]
155
+ step_mult = 0
156
+ for c in range(1, size//2+1):
157
+ step_mult += 1
158
+ # steps left
159
+ for k in range(step_mult):
160
+ i = i - 1
161
+ j = j
162
+ idx.append(indices[i, j])
163
+
164
+ # step down
165
+ for k in range(step_mult):
166
+ i = i
167
+ j = j + 1
168
+ idx.append(indices[i, j])
169
+
170
+ step_mult += 1
171
+ if c < size//2:
172
+ # step right
173
+ for k in range(step_mult):
174
+ i = i + 1
175
+ j = j
176
+ idx.append(indices[i, j])
177
+
178
+ # step up
179
+ for k in range(step_mult):
180
+ i = i
181
+ j = j - 1
182
+ idx.append(indices[i, j])
183
+ else:
184
+ # end reached
185
+ for k in range(step_mult-1):
186
+ i = i + 1
187
+ idx.append(indices[i, j])
188
+
189
+ assert len(idx) == size*size
190
+ idx = idx[::-1]
191
+ idx = torch.tensor(idx)
192
+ self.register_buffer('forward_shuffle_idx', idx)
193
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
194
+
195
+ def forward(self, x, reverse=False):
196
+ if not reverse:
197
+ return x[:, self.forward_shuffle_idx]
198
+ else:
199
+ return x[:, self.backward_shuffle_idx]
200
+
201
+
202
+ class Random(nn.Module):
203
+ def __init__(self, H, W):
204
+ super().__init__()
205
+ indices = np.random.RandomState(1).permutation(H*W)
206
+ idx = torch.tensor(indices.ravel())
207
+ self.register_buffer('forward_shuffle_idx', idx)
208
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
209
+
210
+ def forward(self, x, reverse=False):
211
+ if not reverse:
212
+ return x[:, self.forward_shuffle_idx]
213
+ else:
214
+ return x[:, self.backward_shuffle_idx]
215
+
216
+
217
+ class AlternateParsing(AbstractPermuter):
218
+ def __init__(self, H, W):
219
+ super().__init__()
220
+ indices = np.arange(W*H).reshape(H,W)
221
+ for i in range(1, H, 2):
222
+ indices[i, :] = indices[i, ::-1]
223
+ idx = indices.flatten()
224
+ assert len(idx) == H*W
225
+ idx = torch.tensor(idx)
226
+ self.register_buffer('forward_shuffle_idx', idx)
227
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
228
+
229
+ def forward(self, x, reverse=False):
230
+ if not reverse:
231
+ return x[:, self.forward_shuffle_idx]
232
+ else:
233
+ return x[:, self.backward_shuffle_idx]
234
+
235
+
236
+ if __name__ == "__main__":
237
+ p0 = AlternateParsing(16, 16)
238
+ print(p0.forward_shuffle_idx)
239
+ print(p0.backward_shuffle_idx)
240
+
241
+ x = torch.randint(0, 768, size=(11, 256))
242
+ y = p0(x)
243
+ xre = p0(y, reverse=True)
244
+ assert torch.equal(x, xre)
245
+
246
+ p1 = SpiralOut(2, 2)
247
+ print(p1.forward_shuffle_idx)
248
+ print(p1.backward_shuffle_idx)
taming/modules/util.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def count_params(model):
6
+ total_params = sum(p.numel() for p in model.parameters())
7
+ return total_params
8
+
9
+
10
+ class ActNorm(nn.Module):
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
95
+ class AbstractEncoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def encode(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+
103
+ class Labelator(AbstractEncoder):
104
+ """Net2Net Interface for Class-Conditional Model"""
105
+ def __init__(self, n_classes, quantize_interface=True):
106
+ super().__init__()
107
+ self.n_classes = n_classes
108
+ self.quantize_interface = quantize_interface
109
+
110
+ def encode(self, c):
111
+ c = c[:,None]
112
+ if self.quantize_interface:
113
+ return c, None, [None, None, c.long()]
114
+ return c
115
+
116
+
117
+ class SOSProvider(AbstractEncoder):
118
+ # for unconditional training
119
+ def __init__(self, sos_token, quantize_interface=True):
120
+ super().__init__()
121
+ self.sos_token = sos_token
122
+ self.quantize_interface = quantize_interface
123
+
124
+ def encode(self, x):
125
+ # get batch size from data and replicate sos_token
126
+ c = torch.ones(x.shape[0], 1)*self.sos_token
127
+ c = c.long().to(x.device)
128
+ if self.quantize_interface:
129
+ return c, None, [None, None, c]
130
+ return c
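A short sketch (illustrative, not from the diff) of the data-dependent initialisation performed by ActNorm: the first batch seen in training mode fixes loc and scale so the output is roughly zero-mean and unit-variance per channel, and the layer stays invertible.

import torch
from taming.modules.util import ActNorm

actnorm = ActNorm(num_features=8)
x = torch.randn(16, 8, 32, 32) * 3.0 + 1.0
y = actnorm(x)                              # first forward call triggers the init
print(y.mean().item(), y.std().item())      # close to 0 and 1
x_rec = actnorm(y, reverse=True)            # inverse transform recovers the input
print(torch.allclose(x, x_rec, atol=1e-4))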
taming/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,445 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch import einsum
6
+ from einops import rearrange
7
+
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ """
11
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
+ ____________________________________________
13
+ Discretization bottleneck part of the VQ-VAE.
14
+ Inputs:
15
+ - n_e : number of embeddings
16
+ - e_dim : dimension of embedding
17
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
+ _____________________________________________
19
+ """
20
+
21
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
+ # used wherever VectorQuantizer has been used before and is additionally
24
+ # more efficient.
25
+ def __init__(self, n_e, e_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.n_e = n_e
28
+ self.e_dim = e_dim
29
+ self.beta = beta
30
+
31
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
+
34
+ def forward(self, z):
35
+ """
36
+ Inputs the output of the encoder network z and maps it to a discrete
37
+ one-hot vector that is the index of the closest embedding vector e_j
38
+ z (continuous) -> z_q (discrete)
39
+ z.shape = (batch, channel, height, width)
40
+ quantization pipeline:
41
+ 1. get encoder input (B,C,H,W)
42
+ 2. flatten input to (B*H*W,C)
43
+ """
44
+ # reshape z -> (batch, height, width, channel) and flatten
45
+ z = z.permute(0, 2, 3, 1).contiguous()
46
+ z_flattened = z.view(-1, self.e_dim)
47
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
+
49
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
51
+ torch.matmul(z_flattened, self.embedding.weight.t())
52
+
53
+ ## could possible replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
+
58
+ min_encodings = torch.zeros(
59
+ min_encoding_indices.shape[0], self.n_e).to(z)
60
+ min_encodings.scatter_(1, min_encoding_indices, 1)
61
+
62
+ # dtype min encodings: torch.float32
63
+ # min_encodings shape: torch.Size([2048, 512])
64
+ # min_encoding_indices.shape: torch.Size([2048, 1])
65
+
66
+ # get quantized latent vectors
67
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
+ #.........\end
69
+
70
+ # with:
71
+ # .........\start
72
+ #min_encoding_indices = torch.argmin(d, dim=1)
73
+ #z_q = self.embedding(min_encoding_indices)
74
+ # ......\end......... (TODO)
75
+
76
+ # compute loss for embedding
77
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
78
+ torch.mean((z_q - z.detach()) ** 2)
79
+
80
+ # preserve gradients
81
+ z_q = z + (z_q - z).detach()
82
+
83
+ # perplexity
84
+ e_mean = torch.mean(min_encodings, dim=0)
85
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
+
87
+ # reshape back to match original input shape
88
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
+
90
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
+
92
+ def get_codebook_entry(self, indices, shape):
93
+ # shape specifying (batch, height, width, channel)
94
+ # TODO: check for more easy handling with nn.Embedding
95
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
+ min_encodings.scatter_(1, indices[:,None], 1)
97
+
98
+ # get quantized latent vectors
99
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
+
101
+ if shape is not None:
102
+ z_q = z_q.view(shape)
103
+
104
+ # reshape back to match original input shape
105
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
+
107
+ return z_q
108
+
109
+
110
+ class GumbelQuantize(nn.Module):
111
+ """
112
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
+ Gumbel Softmax trick quantizer
114
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
+ https://arxiv.org/abs/1611.01144
116
+ """
117
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
118
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
119
+ remap=None, unknown_index="random"):
120
+ super().__init__()
121
+
122
+ self.embedding_dim = embedding_dim
123
+ self.n_embed = n_embed
124
+
125
+ self.straight_through = straight_through
126
+ self.temperature = temp_init
127
+ self.kl_weight = kl_weight
128
+
129
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
130
+ self.embed = nn.Embedding(n_embed, embedding_dim)
131
+
132
+ self.use_vqinterface = use_vqinterface
133
+
134
+ self.remap = remap
135
+ if self.remap is not None:
136
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
137
+ self.re_embed = self.used.shape[0]
138
+ self.unknown_index = unknown_index # "random" or "extra" or integer
139
+ if self.unknown_index == "extra":
140
+ self.unknown_index = self.re_embed
141
+ self.re_embed = self.re_embed+1
142
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
143
+ f"Using {self.unknown_index} for unknown indices.")
144
+ else:
145
+ self.re_embed = n_embed
146
+
147
+ def remap_to_used(self, inds):
148
+ ishape = inds.shape
149
+ assert len(ishape)>1
150
+ inds = inds.reshape(ishape[0],-1)
151
+ used = self.used.to(inds)
152
+ match = (inds[:,:,None]==used[None,None,...]).long()
153
+ new = match.argmax(-1)
154
+ unknown = match.sum(2)<1
155
+ if self.unknown_index == "random":
156
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
157
+ else:
158
+ new[unknown] = self.unknown_index
159
+ return new.reshape(ishape)
160
+
161
+ def unmap_to_all(self, inds):
162
+ ishape = inds.shape
163
+ assert len(ishape)>1
164
+ inds = inds.reshape(ishape[0],-1)
165
+ used = self.used.to(inds)
166
+ if self.re_embed > self.used.shape[0]: # extra token
167
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
168
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
169
+ return back.reshape(ishape)
170
+
171
+ def forward(self, z, temp=None, return_logits=False):
172
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
173
+ hard = self.straight_through if self.training else True
174
+ temp = self.temperature if temp is None else temp
175
+
176
+ logits = self.proj(z)
177
+ if self.remap is not None:
178
+ # continue only with used logits
179
+ full_zeros = torch.zeros_like(logits)
180
+ logits = logits[:,self.used,...]
181
+
182
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
183
+ if self.remap is not None:
184
+ # go back to all entries but unused set to zero
185
+ full_zeros[:,self.used,...] = soft_one_hot
186
+ soft_one_hot = full_zeros
187
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
188
+
189
+ # + kl divergence to the prior loss
190
+ qy = F.softmax(logits, dim=1)
191
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
192
+
193
+ ind = soft_one_hot.argmax(dim=1)
194
+ if self.remap is not None:
195
+ ind = self.remap_to_used(ind)
196
+ if self.use_vqinterface:
197
+ if return_logits:
198
+ return z_q, diff, (None, None, ind), logits
199
+ return z_q, diff, (None, None, ind)
200
+ return z_q, diff, ind
201
+
202
+ def get_codebook_entry(self, indices, shape):
203
+ b, h, w, c = shape
204
+ assert b*h*w == indices.shape[0]
205
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
206
+ if self.remap is not None:
207
+ indices = self.unmap_to_all(indices)
208
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
209
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
210
+ return z_q
211
+
212
+
213
+ class VectorQuantizer2(nn.Module):
214
+ """
215
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
216
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
217
+ """
218
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
219
+ # backwards compatibility we use the buggy version by default, but you can
220
+ # specify legacy=False to fix it.
221
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
222
+ sane_index_shape=False, legacy=True):
223
+ super().__init__()
224
+ self.n_e = n_e
225
+ self.e_dim = e_dim
226
+ self.beta = beta
227
+ self.legacy = legacy
228
+
229
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
230
+         self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index  # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed+1
+             print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+                   f"Using {self.unknown_index} for unknown indices.")
+         else:
+             self.re_embed = n_e
+
+         self.sane_index_shape = sane_index_shape
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         match = (inds[:, :, None] == used[None, None, ...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2) < 1
+         if self.unknown_index == "random":
+             new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]:  # extra token
+             inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+         back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+         assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+         assert rescale_logits == False, "Only for interface compatible with Gumbel"
+         assert return_logits == False, "Only for interface compatible with Gumbel"
+         # reshape z -> (batch, height, width, channel) and flatten
+         z = rearrange(z, 'b c h w -> b h w c').contiguous()
+         z_flattened = z.view(-1, self.e_dim)
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+         min_encoding_indices = torch.argmin(d, dim=1)
+         z_q = self.embedding(min_encoding_indices).view(z.shape)
+         perplexity = None
+         min_encodings = None
+
+         # compute loss for embedding
+         if not self.legacy:
+             loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+                    torch.mean((z_q - z.detach()) ** 2)
+         else:
+             loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+                    torch.mean((z_q - z.detach()) ** 2)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+         if self.remap is not None:
+             min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
+             min_encoding_indices = self.remap_to_used(min_encoding_indices)
+             min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten
+
+         if self.sane_index_shape:
+             min_encoding_indices = min_encoding_indices.reshape(
+                 z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+     def get_codebook_entry(self, indices, shape):
+         # shape specifying (batch, height, width, channel)
+         if self.remap is not None:
+             indices = indices.reshape(shape[0], -1)  # add batch axis
+             indices = self.unmap_to_all(indices)
+             indices = indices.reshape(-1)  # flatten again
+
+         # get quantized latent vectors
+         z_q = self.embedding(indices)
+
+         if shape is not None:
+             z_q = z_q.view(shape)
+             # reshape back to match original input shape
+             z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q
+
+
+ class EmbeddingEMA(nn.Module):
+     def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+         super().__init__()
+         self.decay = decay
+         self.eps = eps
+         weight = torch.randn(num_tokens, codebook_dim)
+         self.weight = nn.Parameter(weight, requires_grad=False)
+         self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+         self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+         self.update = True
+
+     def forward(self, embed_id):
+         return F.embedding(embed_id, self.weight)
+
+     def cluster_size_ema_update(self, new_cluster_size):
+         self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+     def embed_avg_ema_update(self, new_embed_avg):
+         self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+     def weight_update(self, num_tokens):
+         n = self.cluster_size.sum()
+         smoothed_cluster_size = (
+             (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+         )
+         # normalize embedding average with smoothed cluster size
+         embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+         self.weight.data.copy_(embed_normalized)
+
+
+ class EMAVectorQuantizer(nn.Module):
+     def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                  remap=None, unknown_index="random"):
+         super().__init__()
+         self.codebook_dim = embedding_dim
+         self.num_tokens = n_embed
+         self.n_embed = n_embed
+         self.beta = beta
+         self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+         self.remap = remap
+         if self.remap is not None:
+             self.register_buffer("used", torch.tensor(np.load(self.remap)))
+             self.re_embed = self.used.shape[0]
+             self.unknown_index = unknown_index  # "random" or "extra" or integer
+             if self.unknown_index == "extra":
+                 self.unknown_index = self.re_embed
+                 self.re_embed = self.re_embed+1
+             print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+                   f"Using {self.unknown_index} for unknown indices.")
+         else:
+             self.re_embed = n_embed
+
+     def remap_to_used(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         match = (inds[:, :, None] == used[None, None, ...]).long()
+         new = match.argmax(-1)
+         unknown = match.sum(2) < 1
+         if self.unknown_index == "random":
+             new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+         else:
+             new[unknown] = self.unknown_index
+         return new.reshape(ishape)
+
+     def unmap_to_all(self, inds):
+         ishape = inds.shape
+         assert len(ishape) > 1
+         inds = inds.reshape(ishape[0], -1)
+         used = self.used.to(inds)
+         if self.re_embed > self.used.shape[0]:  # extra token
+             inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+         back = torch.gather(used[None, :][inds.shape[0]*[0], :], 1, inds)
+         return back.reshape(ishape)
+
+     def forward(self, z):
+         # reshape z -> (batch, height, width, channel) and flatten
+         # z, 'b c h w -> b h w c'
+         z = rearrange(z, 'b c h w -> b h w c')
+         z_flattened = z.reshape(-1, self.codebook_dim)
+
+         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+         d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+             self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+             torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
+
+         encoding_indices = torch.argmin(d, dim=1)
+
+         z_q = self.embedding(encoding_indices).view(z.shape)
+         encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         if self.training and self.embedding.update:
+             # EMA cluster size
+             encodings_sum = encodings.sum(0)
+             self.embedding.cluster_size_ema_update(encodings_sum)
+             # EMA embedding average
+             embed_sum = encodings.transpose(0, 1) @ z_flattened
+             self.embedding.embed_avg_ema_update(embed_sum)
+             # normalize embed_avg and update weight
+             self.embedding.weight_update(self.num_tokens)
+
+         # compute loss for embedding
+         loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         # z_q, 'b h w c -> b c h w'
+         z_q = rearrange(z_q, 'b h w c -> b c h w')
+         return z_q, loss, (perplexity, encodings, encoding_indices)
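
For quick orientation, here is a minimal sketch of how these quantizers can be exercised on their own, outside the VQGAN models that consume them. The class name `VectorQuantizer2` and its constructor arguments are assumed from the upstream taming-transformers code (the class definition sits above this excerpt); the codebook size, embedding dimension, and latent grid shape below are illustrative values, not part of this commit.

    import torch
    from taming.modules.vqvae.quantize import VectorQuantizer2, EMAVectorQuantizer

    # Illustrative sizes: 1024-entry codebook, 256-dim codes, 16x16 latent grid.
    quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, sane_index_shape=True)
    z = torch.randn(2, 256, 16, 16)               # (batch, channels, height, width)
    z_q, vq_loss, (_, _, indices) = quantizer(z)
    # z_q has the same shape as z; with sane_index_shape=True, indices is (2, 16, 16).

    # The EMA variant updates its codebook only in training mode, through the
    # EmbeddingEMA buffers rather than through a codebook gradient.
    ema_quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
    z_q_ema, ema_loss, _ = ema_quantizer(z)
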
taming/util.py ADDED
@@ -0,0 +1,157 @@
+ import os, hashlib
+ import requests
+ from tqdm import tqdm
+
+ URL_MAP = {
+     "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+ }
+
+ CKPT_MAP = {
+     "vgg_lpips": "vgg.pth"
+ }
+
+ MD5_MAP = {
+     "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+ }
+
+
+ def download(url, local_path, chunk_size=1024):
+     os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+     with requests.get(url, stream=True) as r:
+         total_size = int(r.headers.get("content-length", 0))
+         with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+             with open(local_path, "wb") as f:
+                 for data in r.iter_content(chunk_size=chunk_size):
+                     if data:
+                         f.write(data)
+                         pbar.update(chunk_size)
+
+
+ def md5_hash(path):
+     with open(path, "rb") as f:
+         content = f.read()
+     return hashlib.md5(content).hexdigest()
+
+
+ def get_ckpt_path(name, root, check=False):
+     assert name in URL_MAP
+     path = os.path.join(root, CKPT_MAP[name])
+     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+         download(URL_MAP[name], path)
+         md5 = md5_hash(path)
+         assert md5 == MD5_MAP[name], md5
+     return path
+
+
+ class KeyNotFoundError(Exception):
+     def __init__(self, cause, keys=None, visited=None):
+         self.cause = cause
+         self.keys = keys
+         self.visited = visited
+         messages = list()
+         if keys is not None:
+             messages.append("Key not found: {}".format(keys))
+         if visited is not None:
+             messages.append("Visited: {}".format(visited))
+         messages.append("Cause:\n{}".format(cause))
+         message = "\n".join(messages)
+         super().__init__(message)
+
+
+ def retrieve(
+     list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+ ):
+     """Given a nested list or dict return the desired value at key expanding
+     callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+     is done in-place.
+
+     Parameters
+     ----------
+     list_or_dict : list or dict
+         Possibly nested list or dictionary.
+     key : str
+         key/to/value, path like string describing all keys necessary to
+         consider to get to the desired value. List indices can also be
+         passed here.
+     splitval : str
+         String that defines the delimiter between keys of the
+         different depth levels in `key`.
+     default : obj
+         Value returned if :attr:`key` is not found.
+     expand : bool
+         Whether to expand callable nodes on the path or not.
+
+     Returns
+     -------
+     The desired value or if :attr:`default` is not ``None`` and the
+     :attr:`key` is not found returns ``default``.
+
+     Raises
+     ------
+     Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+     ``None``.
+     """
+
+     keys = key.split(splitval)
+
+     success = True
+     try:
+         visited = []
+         parent = None
+         last_key = None
+         for key in keys:
+             if callable(list_or_dict):
+                 if not expand:
+                     raise KeyNotFoundError(
+                         ValueError(
+                             "Trying to get past callable node with expand=False."
+                         ),
+                         keys=keys,
+                         visited=visited,
+                     )
+                 list_or_dict = list_or_dict()
+                 parent[last_key] = list_or_dict
+
+             last_key = key
+             parent = list_or_dict
+
+             try:
+                 if isinstance(list_or_dict, dict):
+                     list_or_dict = list_or_dict[key]
+                 else:
+                     list_or_dict = list_or_dict[int(key)]
+             except (KeyError, IndexError, ValueError) as e:
+                 raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+             visited += [key]
+         # final expansion of retrieved value
+         if expand and callable(list_or_dict):
+             list_or_dict = list_or_dict()
+             parent[last_key] = list_or_dict
+     except KeyNotFoundError as e:
+         if default is None:
+             raise e
+         else:
+             list_or_dict = default
+             success = False
+
+     if not pass_success:
+         return list_or_dict
+     else:
+         return list_or_dict, success
+
+
+ if __name__ == "__main__":
+     config = {"keya": "a",
+               "keyb": "b",
+               "keyc":
+                   {"cc1": 1,
+                    "cc2": 2,
+                    }
+               }
+     from omegaconf import OmegaConf
+     config = OmegaConf.create(config)
+     print(config)
+     retrieve(config, "keya")
+
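
A brief usage sketch for the helpers above; the cache directory and config keys are illustrative, not part of this commit. `get_ckpt_path` downloads and md5-checks the LPIPS VGG weights on first use, and `retrieve` walks a nested config with a "/"-separated key, returning `default` when the key is missing.

    from taming.util import get_ckpt_path, retrieve

    # Fetch (or reuse) the cached LPIPS VGG checkpoint under a hypothetical root directory.
    ckpt = get_ckpt_path("vgg_lpips", root="checkpoints/lpips", check=True)

    # Path-style lookup into a nested dict; the second call falls back to its default.
    cfg = {"model": {"params": {"embed_dim": 256}}}
    embed_dim = retrieve(cfg, "model/params/embed_dim")            # -> 256
    n_layers = retrieve(cfg, "model/params/n_layers", default=12)  # -> 12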