Spaces:
Runtime error
Upload 40 files
- taming/__pycache__/util.cpython-38.pyc +0 -0
- taming/data/ade20k.py +124 -0
- taming/data/annotated_objects_coco.py +139 -0
- taming/data/annotated_objects_dataset.py +218 -0
- taming/data/annotated_objects_open_images.py +137 -0
- taming/data/base.py +70 -0
- taming/data/coco.py +176 -0
- taming/data/conditional_builder/objects_bbox.py +60 -0
- taming/data/conditional_builder/objects_center_points.py +168 -0
- taming/data/conditional_builder/utils.py +105 -0
- taming/data/custom.py +38 -0
- taming/data/faceshq.py +134 -0
- taming/data/helper_types.py +49 -0
- taming/data/image_transforms.py +132 -0
- taming/data/imagenet.py +558 -0
- taming/data/open_images_helper.py +379 -0
- taming/data/sflckr.py +91 -0
- taming/data/utils.py +169 -0
- taming/lr_scheduler.py +34 -0
- taming/models/cond_transformer.py +352 -0
- taming/models/dummy_cond_stage.py +22 -0
- taming/models/vqgan.py +404 -0
- taming/modules/__pycache__/util.cpython-38.pyc +0 -0
- taming/modules/autoencoder/lpips/vgg.pth +3 -0
- taming/modules/diffusionmodules/model.py +776 -0
- taming/modules/discriminator/__pycache__/model.cpython-38.pyc +0 -0
- taming/modules/discriminator/model.py +67 -0
- taming/modules/losses/__init__.py +2 -0
- taming/modules/losses/__pycache__/__init__.cpython-38.pyc +0 -0
- taming/modules/losses/__pycache__/lpips.cpython-38.pyc +0 -0
- taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc +0 -0
- taming/modules/losses/lpips.py +123 -0
- taming/modules/losses/segmentation.py +22 -0
- taming/modules/losses/vqperceptual.py +241 -0
- taming/modules/misc/coord.py +31 -0
- taming/modules/transformer/mingpt.py +415 -0
- taming/modules/transformer/permuter.py +248 -0
- taming/modules/util.py +130 -0
- taming/modules/vqvae/quantize.py +445 -0
- taming/util.py +157 -0
taming/__pycache__/util.cpython-38.pyc
ADDED
Binary file (4.12 kB)
taming/data/ade20k.py
ADDED
@@ -0,0 +1,124 @@
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset

from taming.data.sflckr import SegmentationBase  # for examples included in repo


class Examples(SegmentationBase):
    def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv="data/ade20k_examples.txt",
                         data_root="data/ade20k_images",
                         segmentation_root="data/ade20k_segmentations",
                         size=size, random_crop=random_crop,
                         interpolation=interpolation,
                         n_labels=151, shift_segmentation=False)


# With semantic map and scene label
class ADE20kBase(Dataset):
    def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
        self.split = self.get_split()
        self.n_labels = 151  # unknown + 150
        self.data_csv = {"train": "data/ade20k_train.txt",
                         "validation": "data/ade20k_test.txt"}[self.split]
        self.data_root = "data/ade20k_root"
        with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
            self.scene_categories = f.read().splitlines()
        self.scene_categories = dict(line.split() for line in self.scene_categories)
        with open(self.data_csv, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, "images", l)
                           for l in self.image_paths],
            "relative_segmentation_path_": [l.replace(".jpg", ".png")
                                            for l in self.image_paths],
            "segmentation_path_": [os.path.join(self.data_root, "annotations",
                                                l.replace(".jpg", ".png"))
                                   for l in self.image_paths],
            "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
                               for l in self.image_paths],
        }

        size = None if size is not None and size <= 0 else size
        self.size = size
        if crop_size is None:
            self.crop_size = size if size is not None else None
        else:
            self.crop_size = crop_size
        if self.size is not None:
            self.interpolation = interpolation
            self.interpolation = {
                "nearest": cv2.INTER_NEAREST,
                "bilinear": cv2.INTER_LINEAR,
                "bicubic": cv2.INTER_CUBIC,
                "area": cv2.INTER_AREA,
                "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
            self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                 interpolation=self.interpolation)
            self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                        interpolation=cv2.INTER_NEAREST)

        if crop_size is not None:
            self.center_crop = not random_crop
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            self.preprocessor = self.cropper

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        if self.size is not None:
            image = self.image_rescaler(image=image)["image"]
        segmentation = Image.open(example["segmentation_path_"])
        segmentation = np.array(segmentation).astype(np.uint8)
        if self.size is not None:
            segmentation = self.segmentation_rescaler(image=segmentation)["image"]
        if self.size is not None:
            processed = self.preprocessor(image=image, mask=segmentation)
        else:
            processed = {"image": image, "mask": segmentation}
        example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
        segmentation = processed["mask"]
        onehot = np.eye(self.n_labels)[segmentation]
        example["segmentation"] = onehot
        return example


class ADE20kTrain(ADE20kBase):
    # default to random_crop=True
    def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
        super().__init__(config=config, size=size, random_crop=random_crop,
                         interpolation=interpolation, crop_size=crop_size)

    def get_split(self):
        return "train"


class ADE20kValidation(ADE20kBase):
    def get_split(self):
        return "validation"


if __name__ == "__main__":
    dset = ADE20kValidation()
    ex = dset[0]
    for k in ["image", "scene_category", "segmentation"]:
        print(type(ex[k]))
        try:
            print(ex[k].shape)
        except:
            print(ex[k])
taming/data/annotated_objects_coco.py
ADDED
@@ -0,0 +1,139 @@
import json
from itertools import chain
from pathlib import Path
from typing import Iterable, Dict, List, Callable, Any
from collections import defaultdict

from tqdm import tqdm

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, ImageDescription, Category

COCO_PATH_STRUCTURE = {
    'train': {
        'top_level': '',
        'instances_annotations': 'annotations/instances_train2017.json',
        'stuff_annotations': 'annotations/stuff_train2017.json',
        'files': 'train2017'
    },
    'validation': {
        'top_level': '',
        'instances_annotations': 'annotations/instances_val2017.json',
        'stuff_annotations': 'annotations/stuff_val2017.json',
        'files': 'val2017'
    }
}


def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
    return {
        str(img['id']): ImageDescription(
            id=img['id'],
            license=img.get('license'),
            file_name=img['file_name'],
            coco_url=img['coco_url'],
            original_size=(img['width'], img['height']),
            date_captured=img.get('date_captured'),
            flickr_url=img.get('flickr_url')
        )
        for img in description_json
    }


def load_categories(category_json: Iterable) -> Dict[str, Category]:
    return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
            for cat in category_json if cat['name'] != 'other'}


def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
                     category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
    annotations = defaultdict(list)
    total = sum(len(a) for a in annotations_json)
    for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
        image_id = str(ann['image_id'])
        if image_id not in image_descriptions:
            raise ValueError(f'image_id [{image_id}] has no image description.')
        category_id = ann['category_id']
        try:
            category_no = category_no_for_id(str(category_id))
        except KeyError:
            continue

        width, height = image_descriptions[image_id].original_size
        bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)

        annotations[image_id].append(
            Annotation(
                id=ann['id'],
                area=bbox[2]*bbox[3],  # use bbox area
                is_group_of=ann['iscrowd'],
                image_id=ann['image_id'],
                bbox=bbox,
                category_id=str(category_id),
                category_no=category_no
            )
        )
    return dict(annotations)


class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
    def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
        """
        @param data_path: is the path to the following folder structure:
                          coco/
                          ├── annotations
                          │   ├── instances_train2017.json
                          │   ├── instances_val2017.json
                          │   ├── stuff_train2017.json
                          │   └── stuff_val2017.json
                          ├── train2017
                          │   ├── 000000000009.jpg
                          │   ├── 000000000025.jpg
                          │   └── ...
                          ├── val2017
                          │   ├── 000000000139.jpg
                          │   ├── 000000000285.jpg
                          │   └── ...
        @param: split: one of 'train' or 'validation'
        @param: desired image size (give square images)
        """
        super().__init__(**kwargs)
        self.use_things = use_things
        self.use_stuff = use_stuff

        with open(self.paths['instances_annotations']) as f:
            inst_data_json = json.load(f)
        with open(self.paths['stuff_annotations']) as f:
            stuff_data_json = json.load(f)

        category_jsons = []
        annotation_jsons = []
        if self.use_things:
            category_jsons.append(inst_data_json['categories'])
            annotation_jsons.append(inst_data_json['annotations'])
        if self.use_stuff:
            category_jsons.append(stuff_data_json['categories'])
            annotation_jsons.append(stuff_data_json['annotations'])

        self.categories = load_categories(chain(*category_jsons))
        self.filter_categories()
        self.setup_category_id_and_number()

        self.image_descriptions = load_image_descriptions(inst_data_json['images'])
        annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
        self.annotations = self.filter_object_number(annotations, self.min_object_area,
                                                     self.min_objects_per_image, self.max_objects_per_image)
        self.image_ids = list(self.annotations.keys())
        self.clean_up_annotations_and_image_descriptions()

    def get_path_structure(self) -> Dict[str, str]:
        if self.split not in COCO_PATH_STRUCTURE:
            raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
        return COCO_PATH_STRUCTURE[self.split]

    def get_image_path(self, image_id: str) -> Path:
        return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        # noinspection PyProtectedMember
        return self.image_descriptions[image_id]._asdict()
taming/data/annotated_objects_dataset.py
ADDED
@@ -0,0 +1,218 @@
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Union
import warnings

import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms

from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
    Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor


class AnnotatedObjectsDataset(Dataset):
    def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
                 min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
                 crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
                 encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
                 no_object_classes: Optional[int] = None):
        self.data_path = data_path
        self.split = split
        self.keys = keys
        self.target_image_size = target_image_size
        self.min_object_area = min_object_area
        self.min_objects_per_image = min_objects_per_image
        self.max_objects_per_image = max_objects_per_image
        self.crop_method = crop_method
        self.random_flip = random_flip
        self.no_tokens = no_tokens
        self.use_group_parameter = use_group_parameter
        self.encode_crop = encode_crop

        self.annotations = None
        self.image_descriptions = None
        self.categories = None
        self.category_ids = None
        self.category_number = None
        self.image_ids = None
        self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
        self.paths = self.build_paths(self.data_path)
        self._conditional_builders = None
        self.category_allow_list = None
        if category_allow_list_target:
            allow_list = load_object_from_string(category_allow_list_target)
            self.category_allow_list = {name for name, _ in allow_list}
        self.category_mapping = {}
        if category_mapping_target:
            self.category_mapping = load_object_from_string(category_mapping_target)
        self.no_object_classes = no_object_classes

    def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
        top_level = Path(top_level)
        sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
        for path in sub_paths.values():
            if not path.exists():
                raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
        return sub_paths

    @staticmethod
    def load_image_from_disk(path: Path) -> Image:
        return pil_image.open(path).convert('RGB')

    @staticmethod
    def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
        transform_functions = []
        if crop_method == 'none':
            transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
        elif crop_method == 'center':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                CenterCropReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-1d':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                RandomCrop1dReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-2d':
            transform_functions.extend([
                Random2dCropReturnCoordinates(target_image_size),
                transforms.Resize(target_image_size)
            ])
        elif crop_method is None:
            return None
        else:
            raise ValueError(f'Received invalid crop method [{crop_method}].')
        if random_flip:
            transform_functions.append(RandomHorizontalFlipReturn())
        transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
        return transform_functions

    def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
        crop_bbox = None
        flipped = None
        for t in self.transform_functions:
            if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
                crop_bbox, x = t(x)
            elif isinstance(t, RandomHorizontalFlipReturn):
                flipped, x = t(x)
            else:
                x = t(x)
        return crop_bbox, flipped, x

    @property
    def no_classes(self) -> int:
        return self.no_object_classes if self.no_object_classes else len(self.categories)

    @property
    def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
        # cannot set this up in init because no_classes is only known after loading data in init of superclass
        if self._conditional_builders is None:
            self._conditional_builders = {
                'objects_center_points': ObjectsCenterPointsConditionalBuilder(
                    self.no_classes,
                    self.max_objects_per_image,
                    self.no_tokens,
                    self.encode_crop,
                    self.use_group_parameter,
                    getattr(self, 'use_additional_parameters', False)
                ),
                'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
                    self.no_classes,
                    self.max_objects_per_image,
                    self.no_tokens,
                    self.encode_crop,
                    self.use_group_parameter,
                    getattr(self, 'use_additional_parameters', False)
                )
            }
        return self._conditional_builders

    def filter_categories(self) -> None:
        if self.category_allow_list:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
        if self.category_mapping:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}

    def setup_category_id_and_number(self) -> None:
        self.category_ids = list(self.categories.keys())
        self.category_ids.sort()
        if '/m/01s55n' in self.category_ids:
            self.category_ids.remove('/m/01s55n')
            self.category_ids.append('/m/01s55n')
        self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
        if self.category_allow_list is not None and self.category_mapping is None \
                and len(self.category_ids) != len(self.category_allow_list):
            warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
                          'Make sure all names in category_allow_list exist.')

    def clean_up_annotations_and_image_descriptions(self) -> None:
        image_id_set = set(self.image_ids)
        self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
        self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}

    @staticmethod
    def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
                             min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
        filtered = {}
        for image_id, annotations in all_annotations.items():
            annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
            if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
                filtered[image_id] = annotations_with_min_area
        return filtered

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, n: int) -> Dict[str, Any]:
        image_id = self.get_image_id(n)
        sample = self.get_image_description(image_id)
        sample['annotations'] = self.get_annotation(image_id)

        if 'image' in self.keys:
            sample['image_path'] = str(self.get_image_path(image_id))
            sample['image'] = self.load_image_from_disk(sample['image_path'])
            sample['image'] = convert_pil_to_tensor(sample['image'])
            sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
            sample['image'] = sample['image'].permute(1, 2, 0)

        for conditional, builder in self.conditional_builders.items():
            if conditional in self.keys:
                sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])

        if self.keys:
            # only return specified keys
            sample = {key: sample[key] for key in self.keys}
        return sample

    def get_image_id(self, no: int) -> str:
        return self.image_ids[no]

    def get_annotation(self, image_id: str) -> str:
        return self.annotations[image_id]

    def get_textual_label_for_category_id(self, category_id: str) -> str:
        return self.categories[category_id].name

    def get_textual_label_for_category_no(self, category_no: int) -> str:
        return self.categories[self.get_category_id(category_no)].name

    def get_category_number(self, category_id: str) -> int:
        return self.category_number[category_id]

    def get_category_id(self, category_no: int) -> str:
        return self.category_ids[category_no]

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        raise NotImplementedError()

    def get_path_structure(self):
        raise NotImplementedError

    def get_image_path(self, image_id: str) -> Path:
        raise NotImplementedError
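
To make the relationship between AnnotatedObjectsCoco and the base-class keyword arguments above concrete, here is a minimal construction sketch. It assumes a local data/coco folder laid out as in the class docstring; every keyword value below is an illustrative assumption, not a recommended setting.

from taming.data.annotated_objects_coco import AnnotatedObjectsCoco

# All values are illustrative assumptions; data_path must match the docstring layout.
dataset = AnnotatedObjectsCoco(
    use_things=True,
    use_stuff=True,
    data_path="data/coco",
    split="validation",
    keys=["image", "objects_bbox"],  # which entries __getitem__ should return
    target_image_size=256,
    min_object_area=0.00001,
    min_objects_per_image=2,
    max_objects_per_image=30,
    crop_method="center",
    random_flip=False,
    no_tokens=1024,
    use_group_parameter=True,
    encode_crop=True,
)
sample = dataset[0]
# sample["image"]: HxWx3 tensor in [-1, 1]; sample["objects_bbox"]: LongTensor of length
# max_objects_per_image * 3 + 2 (class/corner tokens per object, plus the encoded crop).
print(sample["image"].shape, sample["objects_bbox"].shape)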
taming/data/annotated_objects_open_images.py
ADDED
@@ -0,0 +1,137 @@
from collections import defaultdict
from csv import DictReader, reader as TupleReader
from pathlib import Path
from typing import Dict, List, Any
import warnings

from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, Category
from tqdm import tqdm

OPEN_IMAGES_STRUCTURE = {
    'train': {
        'top_level': '',
        'class_descriptions': 'class-descriptions-boxable.csv',
        'annotations': 'oidv6-train-annotations-bbox.csv',
        'file_list': 'train-images-boxable.csv',
        'files': 'train'
    },
    'validation': {
        'top_level': '',
        'class_descriptions': 'class-descriptions-boxable.csv',
        'annotations': 'validation-annotations-bbox.csv',
        'file_list': 'validation-images.csv',
        'files': 'validation'
    },
    'test': {
        'top_level': '',
        'class_descriptions': 'class-descriptions-boxable.csv',
        'annotations': 'test-annotations-bbox.csv',
        'file_list': 'test-images.csv',
        'files': 'test'
    }
}


def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
                     category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
    annotations: Dict[str, List[Annotation]] = defaultdict(list)
    with open(descriptor_path) as file:
        reader = DictReader(file)
        for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
            width = float(row['XMax']) - float(row['XMin'])
            height = float(row['YMax']) - float(row['YMin'])
            area = width * height
            category_id = row['LabelName']
            if category_id in category_mapping:
                category_id = category_mapping[category_id]
            if area >= min_object_area and category_id in category_no_for_id:
                annotations[row['ImageID']].append(
                    Annotation(
                        id=i,
                        image_id=row['ImageID'],
                        source=row['Source'],
                        category_id=category_id,
                        category_no=category_no_for_id[category_id],
                        confidence=float(row['Confidence']),
                        bbox=(float(row['XMin']), float(row['YMin']), width, height),
                        area=area,
                        is_occluded=bool(int(row['IsOccluded'])),
                        is_truncated=bool(int(row['IsTruncated'])),
                        is_group_of=bool(int(row['IsGroupOf'])),
                        is_depiction=bool(int(row['IsDepiction'])),
                        is_inside=bool(int(row['IsInside']))
                    )
                )
        if 'train' in str(descriptor_path) and i < 14000000:
            warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
    return dict(annotations)


def load_image_ids(csv_path: Path) -> List[str]:
    with open(csv_path) as file:
        reader = DictReader(file)
        return [row['image_name'] for row in reader]


def load_categories(csv_path: Path) -> Dict[str, Category]:
    with open(csv_path) as file:
        reader = TupleReader(file)
        return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}


class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
    def __init__(self, use_additional_parameters: bool, **kwargs):
        """
        @param data_path: is the path to the following folder structure:
                          open_images/
                          │ oidv6-train-annotations-bbox.csv
                          ├── class-descriptions-boxable.csv
                          ├── oidv6-train-annotations-bbox.csv
                          ├── test
                          │   ├── 000026e7ee790996.jpg
                          │   ├── 000062a39995e348.jpg
                          │   └── ...
                          ├── test-annotations-bbox.csv
                          ├── test-images.csv
                          ├── train
                          │   ├── 000002b66c9c498e.jpg
                          │   ├── 000002b97e5471a0.jpg
                          │   └── ...
                          ├── train-images-boxable.csv
                          ├── validation
                          │   ├── 0001eeaf4aed83f9.jpg
                          │   ├── 0004886b7d043cfd.jpg
                          │   └── ...
                          ├── validation-annotations-bbox.csv
                          └── validation-images.csv
        @param: split: one of 'train', 'validation' or 'test'
        @param: desired image size (returns square images)
        """

        super().__init__(**kwargs)
        self.use_additional_parameters = use_additional_parameters

        self.categories = load_categories(self.paths['class_descriptions'])
        self.filter_categories()
        self.setup_category_id_and_number()

        self.image_descriptions = {}
        annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
                                       self.category_number)
        self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
                                                     self.max_objects_per_image)
        self.image_ids = list(self.annotations.keys())
        self.clean_up_annotations_and_image_descriptions()

    def get_path_structure(self) -> Dict[str, str]:
        if self.split not in OPEN_IMAGES_STRUCTURE:
            raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
        return OPEN_IMAGES_STRUCTURE[self.split]

    def get_image_path(self, image_id: str) -> Path:
        return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        image_path = self.get_image_path(image_id)
        return {'file_path': str(image_path), 'file_name': image_path.name}
taming/data/base.py
ADDED
@@ -0,0 +1,70 @@
import bisect
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset


class ConcatDatasetWithIndex(ConcatDataset):
    """Modified from original pytorch code to return dataset idx"""
    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx], dataset_idx


class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def __getitem__(self, i):
        example = dict()
        example["image"] = self.preprocess_image(self.labels["file_path_"][i])
        for k in self.labels:
            example[k] = self.labels[k][i]
        return example


class NumpyPaths(ImagePaths):
    def preprocess_image(self, image_path):
        image = np.load(image_path).squeeze(0)  # 3 x 1024 x 1024
        image = np.transpose(image, (1, 2, 0))
        image = Image.fromarray(image, mode="RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image
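
A minimal usage sketch for ImagePaths, assuming two hypothetical JPEG paths and an installed albumentations/Pillow stack:

from taming.data.base import ImagePaths

# "img_0001.jpg" / "img_0002.jpg" are placeholder paths for illustration only.
dataset = ImagePaths(paths=["img_0001.jpg", "img_0002.jpg"], size=256, random_crop=False)
example = dataset[0]
# The smaller side is rescaled to 256, a 256x256 center crop is taken, and pixel values
# are mapped to float32 in [-1, 1]; the original path is kept under "file_path_".
print(example["image"].shape, example["image"].dtype, example["file_path_"])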
taming/data/coco.py
ADDED
@@ -0,0 +1,176 @@
import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset

from taming.data.sflckr import SegmentationBase  # for examples included in repo


class Examples(SegmentationBase):
    def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv="data/coco_examples.txt",
                         data_root="data/coco_images",
                         segmentation_root="data/coco_segmentations",
                         size=size, random_crop=random_crop,
                         interpolation=interpolation,
                         n_labels=183, shift_segmentation=True)


class CocoBase(Dataset):
    """needed for (image, caption, segmentation) pairs"""
    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None):
        self.split = self.get_split()
        self.size = size
        if crop_size is None:
            self.crop_size = size
        else:
            self.crop_size = crop_size

        self.onehot = onehot_segmentation  # return segmentation as rgb or one hot
        self.stuffthing = use_stuffthing  # include thing in segmentation
        if self.onehot and not self.stuffthing:
            raise NotImplemented("One hot mode is only supported for the "
                                 "stuffthings version because labels are stored "
                                 "a bit different.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
            self.img_id_to_captions = dict()
            self.img_id_to_filepath = dict()
            self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in ["captions_train2017.json",
                                            "captions_val2017.json"]
        if self.stuffthing:
            self.segmentation_prefix = (
                "data/cocostuffthings/val2017" if
                data_json.endswith("captions_val2017.json") else
                "data/cocostuffthings/train2017")
        else:
            self.segmentation_prefix = (
                "data/coco/annotations/stuff_val2017_pixelmaps" if
                data_json.endswith("captions_val2017.json") else
                "data/coco/annotations/stuff_train2017_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
            self.img_id_to_captions[imgdir["id"]] = list()
            pngfilename = imgdir["file_name"].replace("jpg", "png")
            self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
                self.segmentation_prefix, pngfilename)
            if given_files is not None:
                if pngfilename in given_files:
                    self.labels["image_ids"].append(imgdir["id"])
            else:
                self.labels["image_ids"].append(imgdir["id"])

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # there are in average 5 captions per image
            self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
        self.preprocessor = albumentations.Compose(
            [self.rescaler, self.cropper],
            additional_targets={"segmentation": "image"})
        if force_no_crop:
            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler],
                additional_targets={"segmentation": "image"})

    def __len__(self):
        return len(self.labels["image_ids"])

    def preprocess_image(self, image_path, segmentation_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)

        segmentation = Image.open(segmentation_path)
        if not self.onehot and not segmentation.mode == "RGB":
            segmentation = segmentation.convert("RGB")
        segmentation = np.array(segmentation).astype(np.uint8)
        if self.onehot:
            assert self.stuffthing
            # stored in caffe format: unlabeled==255. stuff and thing from
            # 0-181. to be compatible with the labels in
            # https://github.com/nightrome/cocostuff/blob/master/labels.txt
            # we shift stuffthing one to the right and put unlabeled in zero
            # as long as segmentation is uint8 shifting to right handles the
            # latter too
            assert segmentation.dtype == np.uint8
            segmentation = segmentation + 1

        processed = self.preprocessor(image=image, segmentation=segmentation)
        image, segmentation = processed["image"], processed["segmentation"]
        image = (image / 127.5 - 1.0).astype(np.float32)

        if self.onehot:
            assert segmentation.dtype == np.uint8
            # make it one hot
            n_labels = 183
            flatseg = np.ravel(segmentation)
            onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
            onehot[np.arange(flatseg.size), flatseg] = True
            onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
            segmentation = onehot
        else:
            segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
        return image, segmentation

    def __getitem__(self, i):
        img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
        seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
        image, segmentation = self.preprocess_image(img_path, seg_path)
        captions = self.img_id_to_captions[self.labels["image_ids"][i]]
        # randomly draw one of all available captions per image
        caption = captions[np.random.randint(0, len(captions))]
        example = {"image": image,
                   "caption": [str(caption[0])],
                   "segmentation": segmentation,
                   "img_path": img_path,
                   "seg_path": seg_path,
                   "filename_": img_path.split(os.sep)[-1]
                   }
        return example


class CocoImagesAndCaptionsTrain(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
        super().__init__(size=size,
                         dataroot="data/coco/train2017",
                         datajson="data/coco/annotations/captions_train2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)

    def get_split(self):
        return "train"


class CocoImagesAndCaptionsValidation(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 given_files=None):
        super().__init__(size=size,
                         dataroot="data/coco/val2017",
                         datajson="data/coco/annotations/captions_val2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
                         given_files=given_files)

    def get_split(self):
        return "validation"
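
The one-hot branch in CocoBase.preprocess_image relies on uint8 wraparound to send the unlabeled value 255 to class 0 after the +1 shift. A standalone sketch of just that step, on a toy 2x2 label map (values chosen for illustration):

import numpy as np

segmentation = np.array([[255, 0], [1, 181]], dtype=np.uint8)  # toy label map (assumption)
segmentation = segmentation + 1        # uint8 overflow: 255 (unlabeled) wraps to 0
n_labels = 183
flatseg = np.ravel(segmentation)
onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
onehot[np.arange(flatseg.size), flatseg] = True
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
print(onehot.shape)  # (2, 2, 183), one channel per shifted label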
taming/data/conditional_builder/objects_bbox.py
ADDED
@@ -0,0 +1,60 @@
from itertools import cycle
from typing import List, Tuple, Callable, Optional

from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor

from taming.data.helper_types import BoundingBox, Annotation
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
    pad_list, get_plot_font_size, absolute_bbox


class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
    @property
    def object_descriptor_length(self) -> int:
        return 3

    def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
        object_triples = [
            (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
            for ann in annotations
        ]
        empty_triple = (self.none, self.none, self.none)
        object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
        return object_triples

    def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
        conditional_list = conditional.tolist()
        crop_coordinates = None
        if self.encode_crop:
            crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
            conditional_list = conditional_list[:-2]
        object_triples = grouper(conditional_list, 3)
        assert conditional.shape[0] == self.embedding_dim
        return [
            (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
            for object_triple in object_triples if object_triple[0] != self.none
        ], crop_coordinates

    def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
             line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
        plot = pil_image.new('RGB', figure_size, WHITE)
        draw = pil_img_draw.Draw(plot)
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
            size=get_plot_font_size(font_size, figure_size)
        )
        width, height = plot.size
        description, crop_coordinates = self.inverse_build(conditional)
        for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
            annotation = self.representation_to_annotation(representation)
            class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
            bbox = absolute_bbox(bbox, width, height)
            draw.rectangle(bbox, outline=color, width=line_width)
            draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
        if crop_coordinates is not None:
            draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
        return convert_pil_to_tensor(plot) / 127.5 - 1.
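
A hedged sketch of how one relative bounding box becomes the corner-token pair used in the triples above; the constructor values are illustrative assumptions.

from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder

builder = ObjectsBoundingBoxConditionalBuilder(
    no_object_classes=183, no_max_objects=30, no_tokens=1024,
    encode_crop=False, use_group_parameter=True, use_additional_parameters=False)
# (x0, y0, w, h) in relative [0, 1] coordinates; the top-left and bottom-right corners
# are tokenized separately, so each object occupies three sequence positions:
# (representation token, top-left token, bottom-right token).
tl_token, br_token = builder.token_pair_from_bbox((0.1, 0.2, 0.3, 0.4))
print(tl_token, br_token)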
taming/data/conditional_builder/objects_center_points.py
ADDED
@@ -0,0 +1,168 @@
import math
import random
import warnings
from itertools import cycle
from typing import List, Optional, Tuple, Callable

from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
    additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
    absolute_bbox, rescale_annotations
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor


class ObjectsCenterPointsConditionalBuilder:
    def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
                 use_group_parameter: bool, use_additional_parameters: bool):
        self.no_object_classes = no_object_classes
        self.no_max_objects = no_max_objects
        self.no_tokens = no_tokens
        self.encode_crop = encode_crop
        self.no_sections = int(math.sqrt(self.no_tokens))
        self.use_group_parameter = use_group_parameter
        self.use_additional_parameters = use_additional_parameters

    @property
    def none(self) -> int:
        return self.no_tokens - 1

    @property
    def object_descriptor_length(self) -> int:
        return 2

    @property
    def embedding_dim(self) -> int:
        extra_length = 2 if self.encode_crop else 0
        return self.no_max_objects * self.object_descriptor_length + extra_length

    def tokenize_coordinates(self, x: float, y: float) -> int:
        """
        Express 2d coordinates with one number.
        Example: assume self.no_tokens = 16, then no_sections = 4:
        0  0  0  0
        0  0  #  0
        0  0  0  0
        0  0  0  x
        Then the # position corresponds to token 6, the x position to token 15.
        @param x: float in [0, 1]
        @param y: float in [0, 1]
        @return: discrete tokenized coordinate
        """
        x_discrete = int(round(x * (self.no_sections - 1)))
        y_discrete = int(round(y * (self.no_sections - 1)))
        return y_discrete * self.no_sections + x_discrete

    def coordinates_from_token(self, token: int) -> (float, float):
        x = token % self.no_sections
        y = token // self.no_sections
        return x / (self.no_sections - 1), y / (self.no_sections - 1)

    def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
        x0, y0 = self.coordinates_from_token(token1)
        x1, y1 = self.coordinates_from_token(token2)
        return x0, y0, x1 - x0, y1 - y0

    def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
        return self.tokenize_coordinates(bbox[0], bbox[1]), \
               self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])

    def inverse_build(self, conditional: LongTensor) \
            -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
        conditional_list = conditional.tolist()
        crop_coordinates = None
        if self.encode_crop:
            crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
            conditional_list = conditional_list[:-2]
        table_of_content = grouper(conditional_list, self.object_descriptor_length)
        assert conditional.shape[0] == self.embedding_dim
        return [
            (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
            for object_tuple in table_of_content if object_tuple[0] != self.none
        ], crop_coordinates

    def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
             line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
        plot = pil_image.new('RGB', figure_size, WHITE)
        draw = pil_img_draw.Draw(plot)
        circle_size = get_circle_size(figure_size)
        font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
                                  size=get_plot_font_size(font_size, figure_size))
        width, height = plot.size
        description, crop_coordinates = self.inverse_build(conditional)
        for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
            x_abs, y_abs = x * width, y * height
            ann = self.representation_to_annotation(representation)
            label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
            ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
            draw.ellipse(ellipse_bbox, fill=color, width=0)
            draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
        if crop_coordinates is not None:
            draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
        return convert_pil_to_tensor(plot) / 127.5 - 1.

    def object_representation(self, annotation: Annotation) -> int:
        modifier = 0
        if self.use_group_parameter:
            modifier |= 1 * (annotation.is_group_of is True)
        if self.use_additional_parameters:
            modifier |= 2 * (annotation.is_occluded is True)
            modifier |= 4 * (annotation.is_depiction is True)
            modifier |= 8 * (annotation.is_inside is True)
        return annotation.category_no + self.no_object_classes * modifier

    def representation_to_annotation(self, representation: int) -> Annotation:
        category_no = representation % self.no_object_classes
        modifier = representation // self.no_object_classes
        # noinspection PyTypeChecker
        return Annotation(
            area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
            category_no=category_no,
            is_group_of=bool((modifier & 1) * self.use_group_parameter),
            is_occluded=bool((modifier & 2) * self.use_additional_parameters),
            is_depiction=bool((modifier & 4) * self.use_additional_parameters),
            is_inside=bool((modifier & 8) * self.use_additional_parameters)
        )

    def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
        return list(self.token_pair_from_bbox(crop_coordinates))

    def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
        object_tuples = [
            (self.object_representation(a),
             self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
            for a in annotations
        ]
        empty_tuple = (self.none, self.none)
        object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
        return object_tuples

    def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
            -> LongTensor:
        if len(annotations) == 0:
            warnings.warn('Did not receive any annotations.')
        if len(annotations) > self.no_max_objects:
            warnings.warn('Received more annotations than allowed.')
            annotations = annotations[:self.no_max_objects]

        if not crop_coordinates:
            crop_coordinates = FULL_CROP

        random.shuffle(annotations)
        annotations = filter_annotations(annotations, crop_coordinates)
        if self.encode_crop:
            annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
            if horizontal_flip:
                crop_coordinates = horizontally_flip_bbox(crop_coordinates)
            extra = self._crop_encoder(crop_coordinates)
        else:
            annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
            extra = []

        object_tuples = self._make_object_descriptors(annotations)
        flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
        assert len(flattened) == self.embedding_dim
        assert all(0 <= value < self.no_tokens for value in flattened)
        return LongTensor(flattened)
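
The tokenize_coordinates docstring above can be checked with a self-contained re-implementation of the same formula (no dependency on the class), reproducing its 4x4 example:

import math

def tokenize_coordinates(x: float, y: float, no_tokens: int = 16) -> int:
    # Same arithmetic as ObjectsCenterPointsConditionalBuilder.tokenize_coordinates:
    # the unit square is split into no_sections x no_sections cells, flattened row-major.
    no_sections = int(math.sqrt(no_tokens))
    x_discrete = int(round(x * (no_sections - 1)))
    y_discrete = int(round(y * (no_sections - 1)))
    return y_discrete * no_sections + x_discrete

# Docstring example with no_tokens = 16: the '#' cell (row 1, column 2) is token 6,
# the bottom-right cell is token 15.
assert tokenize_coordinates(2 / 3, 1 / 3) == 6
assert tokenize_coordinates(1.0, 1.0) == 15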
taming/data/conditional_builder/utils.py
ADDED
@@ -0,0 +1,105 @@
import importlib
from typing import List, Any, Tuple, Optional

from taming.data.helper_types import BoundingBox, Annotation

# source: seaborn, color palette tab10
COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
                 (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
BLACK = (0, 0, 0)
GRAY_75 = (63, 63, 63)
GRAY_50 = (127, 127, 127)
GRAY_25 = (191, 191, 191)
WHITE = (255, 255, 255)
FULL_CROP = (0., 0., 1., 1.)


def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
    """
    Give intersection area of two rectangles.
    @param rectangle1: (x0, y0, w, h) of first rectangle
    @param rectangle2: (x0, y0, w, h) of second rectangle
    """
    rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
    rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
    x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
    y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
    return x_overlap * y_overlap


def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
    return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]


def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
    bbox = relative_bbox
    bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
    return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])


def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
    return list_ + [pad_element for _ in range(pad_to_length - len(list_))]


def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
        List[Annotation]:
    def clamp(x: float):
        return max(min(x, 1.), 0.)

    def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
        x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
        y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
        w = min(bbox[2] / crop_coordinates[2], 1 - x0)
        h = min(bbox[3] / crop_coordinates[3], 1 - y0)
        if flip:
            x0 = 1 - (x0 + w)
        return x0, y0, w, h

    return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]


def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
    return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]


def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
    sl = slice(1) if short else slice(None)
    string = ''
    if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
        return string
    if annotation.is_group_of:
        string += 'group'[sl] + ','
    if annotation.is_occluded:
        string += 'occluded'[sl] + ','
    if annotation.is_depiction:
        string += 'depiction'[sl] + ','
    if annotation.is_inside:
        string += 'inside'[sl]
    return '(' + string.strip(",") + ')'


def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
    if font_size is None:
        font_size = 10
        if max(figure_size) >= 256:
            font_size = 12
        if max(figure_size) >= 512:
            font_size = 15
    return font_size


def get_circle_size(figure_size: Tuple[int, int]) -> int:
    circle_size = 2
    if max(figure_size) >= 256:
        circle_size = 3
    if max(figure_size) >= 512:
        circle_size = 4
    return circle_size


def load_object_from_string(object_string: str) -> Any:
    """
    Source: https://stackoverflow.com/a/10773699
    """
    module_name, class_name = object_string.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)
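A quick sanity-check sketch of the bbox helpers above; the boxes are made-up values in the relative (x0, y0, w, h) convention this module uses throughout.

# Standalone usage sketch (not part of the diff); values are illustrative only.
from taming.data.conditional_builder.utils import (
    intersection_area, horizontally_flip_bbox, absolute_bbox, FULL_CROP)

crop = (0.25, 0.25, 0.5, 0.5)
box = (0.4, 0.4, 0.2, 0.2)
print(intersection_area(box, crop))       # ~0.04: the box lies fully inside the crop
print(intersection_area(box, FULL_CROP))  # ~0.04 as well, FULL_CROP covers the whole image
print(horizontally_flip_bbox(box))        # effectively unchanged (up to float rounding): the box is horizontally centered
print(absolute_bbox(crop, width=256, height=256))  # (64, 64, 192, 192) in pixel corner coordinates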
taming/data/custom.py
ADDED
@@ -0,0 +1,38 @@
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)


class CustomTest(CustomBase):
    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
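A hedged usage sketch for the classes above; data/train.txt is a hypothetical list file with one image path per line, which is what CustomTrain expects.

# Usage sketch, not repo code; "data/train.txt" is an assumed path.
from torch.utils.data import DataLoader
from taming.data.custom import CustomTrain

dataset = CustomTrain(size=256, training_images_list_file="data/train.txt")
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch = next(iter(loader))
# ImagePaths (taming/data/base.py) yields HWC float32 images scaled to [-1, 1] under the "image" key.
print(batch["image"].shape)  # e.g. torch.Size([4, 256, 256, 3])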
taming/data/faceshq.py
ADDED
@@ -0,0 +1,134 @@
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class FacesBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        ex = {}
        if self.keys is not None:
            for k in self.keys:
                ex[k] = example[k]
        else:
            ex = example
        return ex


class CelebAHQTrain(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        with open("data/celebahqtrain.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class CelebAHQValidation(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        with open("data/celebahqvalidation.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FFHQTrain(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        with open("data/ffhqtrain.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FFHQValidation(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        with open("data/ffhqvalidation.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FacesHQTrain(Dataset):
    # CelebAHQ [0] + FFHQ [1]
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        d1 = CelebAHQTrain(size=size, keys=keys)
        d2 = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            self.cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
            if self.coord:
                self.cropper = albumentations.Compose([self.cropper],
                                                      additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        if hasattr(self, "cropper"):
            if not self.coord:
                out = self.cropper(image=ex["image"])
                ex["image"] = out["image"]
            else:
                h, w, _ = ex["image"].shape
                coord = np.arange(h*w).reshape(h, w, 1)/(h*w)
                out = self.cropper(image=ex["image"], coord=coord)
                ex["image"] = out["image"]
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex


class FacesHQValidation(Dataset):
    # CelebAHQ [0] + FFHQ [1]
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        d1 = CelebAHQValidation(size=size, keys=keys)
        d2 = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            self.cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
            if self.coord:
                self.cropper = albumentations.Compose([self.cropper],
                                                      additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        if hasattr(self, "cropper"):
            if not self.coord:
                out = self.cropper(image=ex["image"])
                ex["image"] = out["image"]
            else:
                h, w, _ = ex["image"].shape
                coord = np.arange(h*w).reshape(h, w, 1)/(h*w)
                out = self.cropper(image=ex["image"], coord=coord)
                ex["image"] = out["image"]
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex
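A usage sketch, assuming the CelebA-HQ/FFHQ list files and images referenced above (data/celebahqtrain.txt, data/ffhqtrain.txt, ...) have already been placed under data/.

# Usage sketch, not repo code; requires the face datasets to be prepared locally.
from taming.data.faceshq import FacesHQTrain

dataset = FacesHQTrain(size=256, crop_size=256, coord=True)
ex = dataset[0]
# "image" is the cropped face, "coord" a matching positional map, and
# "class" is 0 for CelebA-HQ samples and 1 for FFHQ samples.
print(ex["image"].shape, ex["coord"].shape, ex["class"])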
taming/data/helper_types.py
ADDED
@@ -0,0 +1,49 @@
from typing import Dict, Tuple, Optional, NamedTuple, Union
from PIL.Image import Image as pil_image
from torch import Tensor

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

Image = Union[Tensor, pil_image]
BoundingBox = Tuple[float, float, float, float]  # x0, y0, w, h
CropMethodType = Literal['none', 'random', 'center', 'random-2d']
SplitType = Literal['train', 'validation', 'test']


class ImageDescription(NamedTuple):
    id: int
    file_name: str
    original_size: Tuple[int, int]  # w, h
    url: Optional[str] = None
    license: Optional[int] = None
    coco_url: Optional[str] = None
    date_captured: Optional[str] = None
    flickr_url: Optional[str] = None
    flickr_id: Optional[str] = None
    coco_id: Optional[str] = None


class Category(NamedTuple):
    id: str
    super_category: Optional[str]
    name: str


class Annotation(NamedTuple):
    area: float
    image_id: str
    bbox: BoundingBox
    category_no: int
    category_id: str
    id: Optional[int] = None
    source: Optional[str] = None
    confidence: Optional[float] = None
    is_group_of: Optional[bool] = None
    is_truncated: Optional[bool] = None
    is_occluded: Optional[bool] = None
    is_depiction: Optional[bool] = None
    is_inside: Optional[bool] = None
    segmentation: Optional[Dict] = None
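Since Annotation is a NamedTuple, instances are immutable and are updated via _replace, which is exactly what rescale_annotations in conditional_builder/utils.py relies on. A small sketch with made-up field values:

# Standalone sketch; all field values are illustrative, not real dataset entries.
from taming.data.helper_types import Annotation

a = Annotation(area=0.04, image_id="img_0", bbox=(0.4, 0.4, 0.2, 0.2),
               category_no=17, category_id="/m/01g317", is_group_of=False)
shifted = a._replace(bbox=(0.3, 0.3, 0.2, 0.2))  # returns a new tuple, a is untouched
print(a.bbox, shifted.bbox)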
taming/data/image_transforms.py
ADDED
@@ -0,0 +1,132 @@
import random
import warnings
from typing import Union

import torch
from torch import Tensor
from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
from torchvision.transforms.functional import _get_image_size as get_image_size

from taming.data.helper_types import BoundingBox, Image

pil_to_tensor = PILToTensor()


def convert_pil_to_tensor(image: Image) -> Tensor:
    with warnings.catch_warnings():
        # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
        warnings.simplefilter("ignore")
        return pil_to_tensor(image)


class RandomCrop1dReturnCoordinates(RandomCrop):
    def forward(self, img: Image) -> (BoundingBox, Image):
        """
        Additionally to cropping, returns the relative coordinates of the crop bounding box.
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            Bounding box: x0, y0, w, h
            PIL Image or Tensor: Cropped image.

        Based on:
            torchvision.transforms.RandomCrop, torchvision 1.7.0
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = get_image_size(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)
        bbox = (j / width, i / height, w / width, h / height)  # x0, y0, w, h
        return bbox, F.crop(img, i, j, h, w)


class Random2dCropReturnCoordinates(torch.nn.Module):
    """
    Additionally to cropping, returns the relative coordinates of the crop bounding box.
    Args:
        img (PIL Image or Tensor): Image to be cropped.

    Returns:
        Bounding box: x0, y0, w, h
        PIL Image or Tensor: Cropped image.

    Based on:
        torchvision.transforms.RandomCrop, torchvision 1.7.0
    """

    def __init__(self, min_size: int):
        super().__init__()
        self.min_size = min_size

    def forward(self, img: Image) -> (BoundingBox, Image):
        width, height = get_image_size(img)
        max_size = min(width, height)
        if max_size <= self.min_size:
            size = max_size
        else:
            size = random.randint(self.min_size, max_size)
        top = random.randint(0, height - size)
        left = random.randint(0, width - size)
        bbox = left / width, top / height, size / width, size / height
        return bbox, F.crop(img, top, left, size, size)


class CenterCropReturnCoordinates(CenterCrop):
    @staticmethod
    def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
        if width > height:
            w = height / width
            h = 1.0
            x0 = 0.5 - w / 2
            y0 = 0.
        else:
            w = 1.0
            h = width / height
            x0 = 0.
            y0 = 0.5 - h / 2
        return x0, y0, w, h

    def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
        """
        Additionally to cropping, returns the relative coordinates of the crop bounding box.
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            Bounding box: x0, y0, w, h
            PIL Image or Tensor: Cropped image.
        Based on:
            torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
        """
        width, height = get_image_size(img)
        return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)


class RandomHorizontalFlipReturn(RandomHorizontalFlip):
    def forward(self, img: Image) -> (bool, Image):
        """
        Additionally to flipping, returns a boolean whether it was flipped or not.
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            flipped: whether the image was flipped or not
            PIL Image or Tensor: Randomly flipped image.

        Based on:
            torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
        """
        if torch.rand(1) < self.p:
            return True, F.hflip(img)
        return False, img
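A short sketch of how the *ReturnCoordinates transforms are meant to be used; it assumes a torchvision version that still exposes _get_image_size, as the import at the top of this file does.

# Usage sketch on a synthetic image, not repo code.
from PIL import Image as PILImage
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomHorizontalFlipReturn

img = PILImage.new("RGB", (512, 256))          # synthetic wide image
bbox, cropped = CenterCropReturnCoordinates(256)(img)
print(bbox)                                    # (0.25, 0.0, 0.5, 1.0): the centered square in relative coords
flipped, out = RandomHorizontalFlipReturn(p=0.5)(img)
print(flipped, out.size)                       # True or False, (512, 256)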
taming/data/imagenet.py
ADDED
@@ -0,0 +1,558 @@
import os, tarfile, glob, shutil
import yaml
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu


def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    synsets = []
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    for idx in indices:
        synsets.append(str(di2s[idx]))
    print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
    return synsets


def str_to_indices(string):
    """Expects a string in the format '32-123, 256, 280-321'"""
    assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
    subs = string.split(",")
    indices = []
    for sub in subs:
        subsubs = sub.split("-")
        assert len(subsubs) > 0
        if len(subsubs) == 1:
            indices.append(int(subsubs[0]))
        else:
            rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
            indices.extend(rang)
    return sorted(indices)


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        self.class_labels = [class_dict[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)


def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    if size is not None and size > 0:
        transforms = list()
        rescaler = albumentations.SmallestMaxSize(max_size=size)
        transforms.append(rescaler)
        if not random_crop:
            cropper = albumentations.CenterCrop(height=size, width=size)
            transforms.append(cropper)
        else:
            cropper = albumentations.RandomCrop(height=size, width=size)
            transforms.append(cropper)
            flipper = albumentations.HorizontalFlip()
            transforms.append(flipper)
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    elif crop_size is not None and crop_size > 0:
        if not random_crop:
            cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
        else:
            cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
        transforms = [cropper]
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    else:
        preprocessor = lambda **kwargs: kwargs
    return preprocessor


def rgba_to_depth(x):
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy()
    y.dtype = np.float32
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)


class BaseWithDepth(Dataset):
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        self.preprocessor = get_preprocessor(
            size=size,
            crop_size=crop_size,
            random_crop=random_crop,
            additional_targets={"depth": "image"})
        self.crop_size = crop_size
        if self.crop_size is not None:
            self.rescaler = albumentations.Compose(
                [albumentations.SmallestMaxSize(max_size=self.crop_size)],
                additional_targets={"depth": "image"})
        if root is not None:
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
        depth = 2.0*depth-1.0
        return depth

    def __getitem__(self, i):
        e = self.base_dset[i]
        e["depth"] = self.preprocess_depth(self.get_depth_path(e))
        # up if necessary
        h,w,c = e["image"].shape
        if self.crop_size and min(h,w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            out = self.rescaler(image=e["image"], depth=e["depth"])
            e["image"] = out["image"]
            e["depth"] = out["depth"]
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


class ImageNetTrainWithDepth(BaseWithDepth):
    # default to random_crop=True
    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetTrain()
        else:
            return ImageNetTrain({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0]+".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
        return fid


class ImageNetValidationWithDepth(BaseWithDepth):
    def __init__(self, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetValidation()
        else:
            return ImageNetValidation({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0]+".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
        return fid


class RINTrainWithDepth(ImageNetTrainWithDepth):
    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class RINValidationWithDepth(ImageNetValidationWithDepth):
    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class DRINExamples(Dataset):
    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images",
                                         relpath) for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth",
                                         relpath.replace(".JPEG", ".png")) for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
        depth = 2.0*depth-1.0
        return depth

    def __getitem__(self, i):
        e = dict()
        e["image"] = self.preprocess_image(self.image_paths[i])
        e["depth"] = self.preprocess_depth(self.depth_paths[i])
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    if factor is None or factor==1:
        return x

    dtype = x.dtype
    assert dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1

    keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
                "bicubic": Image.BICUBIC}[keepmode]

    lr = (x+1.0)*127.5
    lr = lr.clip(0,255).astype(np.uint8)
    lr = Image.fromarray(lr)

    h, w, _ = x.shape
    nh = h//factor
    nw = w//factor
    assert nh > 0 and nw > 0, (nh, nw)

    lr = lr.resize((nw,nh), Image.BICUBIC)
    if keepshapes:
        lr = lr.resize((w,h), keepmode)
    lr = np.array(lr)/127.5-1.0
    lr = lr.astype(dtype)

    return lr


class ImageNetScale(Dataset):
    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()

        self.size = size
        self.crop_size = crop_size if crop_size is not None else self.size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        transforms = list()

        if self.size is not None and self.size > 0:
            rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            self.rescaler = rescaler
            transforms.append(rescaler)

        if self.crop_size is not None and self.crop_size > 0:
            if len(transforms) == 0:
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)

            if not self.random_crop:
                cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            transforms.append(cropper)

        if len(transforms) > 0:
            if self.up_factor is not None:
                additional_targets = {"lr": "image"}
            else:
                additional_targets = None
            self.preprocessor = albumentations.Compose(transforms,
                                                       additional_targets=additional_targets)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        # adjust resolution
        image = imscale(image, self.hr_factor, keepshapes=False)
        h,w,c = image.shape
        if self.crop_size and min(h,w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]
        if self.up_factor is None:
            image = self.preprocessor(image=image)["image"]
            example["image"] = image
        else:
            lr = imscale(image, self.up_factor, keepshapes=True,
                         keepmode=self.keep_mode)

            out = self.preprocessor(image=image, lr=lr)
            example["image"] = out["image"]
            example["lr"] = out["lr"]

        return example


class ImageNetScaleTrain(ImageNetScale):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetScaleValidation(ImageNetScale):
    def get_base(self):
        return ImageNetValidation()


from skimage.feature import canny
from skimage.color import rgb2gray


class ImageNetEdges(ImageNetScale):
    def __init__(self, up_factor=1, **kwargs):
        super().__init__(up_factor=1, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        h,w,c = image.shape
        if self.crop_size and min(h,w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]

        lr = canny(rgb2gray(image), sigma=2)
        lr = lr.astype(np.float32)
        lr = lr[:,:,None][:,:,[0,0,0]]

        out = self.preprocessor(image=image, lr=lr)
        example["image"] = out["image"]
        example["lr"] = out["lr"]

        return example


class ImageNetEdgesTrain(ImageNetEdges):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetEdgesValidation(ImageNetEdges):
    def get_base(self):
        return ImageNetValidation()
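A sketch of the two resizing utilities above on synthetic data, so no ImageNet download is needed; it assumes the Space's dependencies (albumentations, scikit-image, omegaconf) are installed, since importing this module pulls them in.

# Usage sketch, not repo code; input array is random data standing in for an image.
import numpy as np
from taming.data.imagenet import get_preprocessor, imscale

rng = np.random.RandomState(0)
img = rng.uniform(-1, 1, size=(300, 400, 3)).astype(np.float32)  # HWC in [-1, 1]

pre = get_preprocessor(size=256)              # SmallestMaxSize(256) followed by CenterCrop(256)
print(pre(image=img)["image"].shape)          # (256, 256, 3)

lr = imscale(img, factor=4, keepshapes=True)  # 4x downscale, then resized back to the original shape
print(lr.shape, lr.dtype)                     # (300, 400, 3) float32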
taming/data/open_images_helper.py
ADDED
@@ -0,0 +1,379 @@
open_images_unify_categories_for_coco = {
    '/m/03bt1vf': '/m/01g317',
    '/m/04yx4': '/m/01g317',
    '/m/05r655': '/m/01g317',
    '/m/01bl7v': '/m/01g317',
    '/m/0cnyhnx': '/m/01xq0k1',
    '/m/01226z': '/m/018xm',
    '/m/05ctyq': '/m/018xm',
    '/m/058qzx': '/m/04ctx',
    '/m/06pcq': '/m/0l515',
    '/m/03m3pdh': '/m/02crq1',
    '/m/046dlr': '/m/01x3z',
    '/m/0h8mzrc': '/m/01x3z',
}


top_300_classes_plus_coco_compatibility = [
    ('Man', 1060962),
    ('Clothing', 986610),
    ('Tree', 748162),
    ('Woman', 611896),
    ('Person', 610294),
    ('Human face', 442948),
    ('Girl', 175399),
    ('Building', 162147),
    ('Car', 159135),
    ('Plant', 155704),
    ('Human body', 137073),
    ('Flower', 133128),
    ('Window', 127485),
    ('Human arm', 118380),
    ('House', 114365),
    ('Wheel', 111684),
    ('Suit', 99054),
    ('Human hair', 98089),
    ('Human head', 92763),
    ('Chair', 88624),
    ('Boy', 79849),
    ('Table', 73699),
    ('Jeans', 57200),
    ('Tire', 55725),
    ('Skyscraper', 53321),
    ('Food', 52400),
    ('Footwear', 50335),
    ('Dress', 50236),
    ('Human leg', 47124),
    ('Toy', 46636),
    ('Tower', 45605),
    ('Boat', 43486),
    ('Land vehicle', 40541),
    ('Bicycle wheel', 34646),
    ('Palm tree', 33729),
    ('Fashion accessory', 32914),
    ('Glasses', 31940),
    ('Bicycle', 31409),
    ('Furniture', 30656),
    ('Sculpture', 29643),
    ('Bottle', 27558),
    ('Dog', 26980),
    ('Snack', 26796),
    ('Human hand', 26664),
    ('Bird', 25791),
    ('Book', 25415),
    ('Guitar', 24386),
    ('Jacket', 23998),
    ('Poster', 22192),
    ('Dessert', 21284),
    ('Baked goods', 20657),
    ('Drink', 19754),
    ('Flag', 18588),
    ('Houseplant', 18205),
    ('Tableware', 17613),
    ('Airplane', 17218),
    ('Door', 17195),
    ('Sports uniform', 17068),
    ('Shelf', 16865),
    ('Drum', 16612),
    ('Vehicle', 16542),
    ('Microphone', 15269),
    ('Street light', 14957),
    ('Cat', 14879),
    ('Fruit', 13684),
    ('Fast food', 13536),
    ('Animal', 12932),
    ('Vegetable', 12534),
    ('Train', 12358),
    ('Horse', 11948),
    ('Flowerpot', 11728),
    ('Motorcycle', 11621),
    ('Fish', 11517),
    ('Desk', 11405),
    ('Helmet', 10996),
    ('Truck', 10915),
    ('Bus', 10695),
    ('Hat', 10532),
    ('Auto part', 10488),
    ('Musical instrument', 10303),
    ('Sunglasses', 10207),
    ('Picture frame', 10096),
    ('Sports equipment', 10015),
    ('Shorts', 9999),
    ('Wine glass', 9632),
    ('Duck', 9242),
    ('Wine', 9032),
    ('Rose', 8781),
    ('Tie', 8693),
    ('Butterfly', 8436),
    ('Beer', 7978),
    ('Cabinetry', 7956),
    ('Laptop', 7907),
    ('Insect', 7497),
    ('Goggles', 7363),
    ('Shirt', 7098),
    ('Dairy Product', 7021),
    ('Marine invertebrates', 7014),
    ('Cattle', 7006),
    ('Trousers', 6903),
    ('Van', 6843),
    ('Billboard', 6777),
    ('Balloon', 6367),
    ('Human nose', 6103),
    ('Tent', 6073),
    ('Camera', 6014),
    ('Doll', 6002),
    ('Coat', 5951),
    ('Mobile phone', 5758),
    ('Swimwear', 5729),
    ('Strawberry', 5691),
    ('Stairs', 5643),
    ('Goose', 5599),
    ('Umbrella', 5536),
    ('Cake', 5508),
    ('Sun hat', 5475),
    ('Bench', 5310),
    ('Bookcase', 5163),
    ('Bee', 5140),
    ('Computer monitor', 5078),
    ('Hiking equipment', 4983),
    ('Office building', 4981),
    ('Coffee cup', 4748),
    ('Curtain', 4685),
    ('Plate', 4651),
    ('Box', 4621),
    ('Tomato', 4595),
    ('Coffee table', 4529),
    ('Office supplies', 4473),
    ('Maple', 4416),
    ('Muffin', 4365),
    ('Cocktail', 4234),
    ('Castle', 4197),
    ('Couch', 4134),
    ('Pumpkin', 3983),
    ('Computer keyboard', 3960),
    ('Human mouth', 3926),
    ('Christmas tree', 3893),
    ('Mushroom', 3883),
    ('Swimming pool', 3809),
    ('Pastry', 3799),
    ('Lavender (Plant)', 3769),
    ('Football helmet', 3732),
    ('Bread', 3648),
    ('Traffic sign', 3628),
    ('Common sunflower', 3597),
    ('Television', 3550),
    ('Bed', 3525),
    ('Cookie', 3485),
    ('Fountain', 3484),
    ('Paddle', 3447),
    ('Bicycle helmet', 3429),
    ('Porch', 3420),
    ('Deer', 3387),
    ('Fedora', 3339),
    ('Canoe', 3338),
    ('Carnivore', 3266),
    ('Bowl', 3202),
    ('Human eye', 3166),
    ('Ball', 3118),
    ('Pillow', 3077),
    ('Salad', 3061),
    ('Beetle', 3060),
    ('Orange', 3050),
    ('Drawer', 2958),
    ('Platter', 2937),
    ('Elephant', 2921),
    ('Seafood', 2921),
    ('Monkey', 2915),
    ('Countertop', 2879),
    ('Watercraft', 2831),
    ('Helicopter', 2805),
    ('Kitchen appliance', 2797),
    ('Personal flotation device', 2781),
    ('Swan', 2739),
    ('Lamp', 2711),
    ('Boot', 2695),
    ('Bronze sculpture', 2693),
    ('Chicken', 2677),
    ('Taxi', 2643),
    ('Juice', 2615),
    ('Cowboy hat', 2604),
    ('Apple', 2600),
    ('Tin can', 2590),
    ('Necklace', 2564),
    ('Ice cream', 2560),
    ('Human beard', 2539),
    ('Coin', 2536),
    ('Candle', 2515),
    ('Cart', 2512),
    ('High heels', 2441),
    ('Weapon', 2433),
    ('Handbag', 2406),
    ('Penguin', 2396),
    ('Rifle', 2352),
    ('Violin', 2336),
    ('Skull', 2304),
    ('Lantern', 2285),
    ('Scarf', 2269),
    ('Saucer', 2225),
    ('Sheep', 2215),
    ('Vase', 2189),
    ('Lily', 2180),
    ('Mug', 2154),
    ('Parrot', 2140),
    ('Human ear', 2137),
    ('Sandal', 2115),
    ('Lizard', 2100),
    ('Kitchen & dining room table', 2063),
    ('Spider', 1977),
    ('Coffee', 1974),
    ('Goat', 1926),
    ('Squirrel', 1922),
    ('Cello', 1913),
    ('Sushi', 1881),
    ('Tortoise', 1876),
    ('Pizza', 1870),
    ('Studio couch', 1864),
    ('Barrel', 1862),
    ('Cosmetics', 1841),
    ('Moths and butterflies', 1841),
    ('Convenience store', 1817),
    ('Watch', 1792),
    ('Home appliance', 1786),
    ('Harbor seal', 1780),
    ('Luggage and bags', 1756),
    ('Vehicle registration plate', 1754),
    ('Shrimp', 1751),
    ('Jellyfish', 1730),
    ('French fries', 1723),
    ('Egg (Food)', 1698),
    ('Football', 1697),
    ('Musical keyboard', 1683),
    ('Falcon', 1674),
    ('Candy', 1660),
    ('Medical equipment', 1654),
    ('Eagle', 1651),
    ('Dinosaur', 1634),
    ('Surfboard', 1630),
    ('Tank', 1628),
    ('Grape', 1624),
    ('Lion', 1624),
    ('Owl', 1622),
    ('Ski', 1613),
    ('Waste container', 1606),
    ('Frog', 1591),
    ('Sparrow', 1585),
    ('Rabbit', 1581),
    ('Pen', 1546),
    ('Sea lion', 1537),
    ('Spoon', 1521),
    ('Sink', 1512),
    ('Teddy bear', 1507),
    ('Bull', 1495),
    ('Sofa bed', 1490),
    ('Dragonfly', 1479),
    ('Brassiere', 1478),
    ('Chest of drawers', 1472),
    ('Aircraft', 1466),
    ('Human foot', 1463),
    ('Pig', 1455),
    ('Fork', 1454),
    ('Antelope', 1438),
    ('Tripod', 1427),
    ('Tool', 1424),
    ('Cheese', 1422),
    ('Lemon', 1397),
    ('Hamburger', 1393),
    ('Dolphin', 1390),
    ('Mirror', 1390),
    ('Marine mammal', 1387),
    ('Giraffe', 1385),
    ('Snake', 1368),
    ('Gondola', 1364),
    ('Wheelchair', 1360),
    ('Piano', 1358),
    ('Cupboard', 1348),
    ('Banana', 1345),
    ('Trumpet', 1335),
    ('Lighthouse', 1333),
    ('Invertebrate', 1317),
    ('Carrot', 1268),
    ('Sock', 1260),
    ('Tiger', 1241),
    ('Camel', 1224),
    ('Parachute', 1224),
    ('Bathroom accessory', 1223),
    ('Earrings', 1221),
    ('Headphones', 1218),
    ('Skirt', 1198),
    ('Skateboard', 1190),
    ('Sandwich', 1148),
    ('Saxophone', 1141),
    ('Goldfish', 1136),
    ('Stool', 1104),
    ('Traffic light', 1097),
    ('Shellfish', 1081),
    ('Backpack', 1079),
    ('Sea turtle', 1078),
    ('Cucumber', 1075),
    ('Tea', 1051),
    ('Toilet', 1047),
    ('Roller skates', 1040),
    ('Mule', 1039),
    ('Bust', 1031),
    ('Broccoli', 1030),
    ('Crab', 1020),
    ('Oyster', 1019),
    ('Cannon', 1012),
    ('Zebra', 1012),
    ('French horn', 1008),
    ('Grapefruit', 998),
    ('Whiteboard', 997),
    ('Zucchini', 997),
    ('Crocodile', 992),

    ('Clock', 960),
    ('Wall clock', 958),

    ('Doughnut', 869),
    ('Snail', 868),

    ('Baseball glove', 859),

    ('Panda', 830),
    ('Tennis racket', 830),

    ('Pear', 652),

    ('Bagel', 617),
    ('Oven', 616),
    ('Ladybug', 615),
    ('Shark', 615),
    ('Polar bear', 614),
    ('Ostrich', 609),

    ('Hot dog', 473),
    ('Microwave oven', 467),
    ('Fire hydrant', 20),
    ('Stop sign', 20),
    ('Parking meter', 20),
    ('Bear', 20),
    ('Flying disc', 20),
    ('Snowboard', 20),
    ('Tennis ball', 20),
    ('Kite', 20),
    ('Baseball bat', 20),
    ('Kitchen knife', 20),
    ('Knife', 20),
    ('Submarine sandwich', 20),
    ('Computer mouse', 20),
    ('Remote control', 20),
    ('Toaster', 20),
    ('Sink', 20),
    ('Refrigerator', 20),
    ('Alarm clock', 20),
    ('Wall clock', 20),
    ('Scissors', 20),
    ('Hair dryer', 20),
    ('Toothbrush', 20),
    ('Suitcase', 20)
]
taming/data/sflckr.py
ADDED
@@ -0,0 +1,91 @@
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset


class SegmentationBase(Dataset):
    def __init__(self,
                 data_csv, data_root, segmentation_root,
                 size=None, random_crop=False, interpolation="bicubic",
                 n_labels=182, shift_segmentation=False,
                 ):
        self.n_labels = n_labels
        self.shift_segmentation = shift_segmentation
        self.data_csv = data_csv
        self.data_root = data_root
        self.segmentation_root = segmentation_root
        with open(self.data_csv, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
            "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
                                   for l in self.image_paths]
        }

        size = None if size is not None and size<=0 else size
        self.size = size
        if self.size is not None:
            self.interpolation = interpolation
            self.interpolation = {
                "nearest": cv2.INTER_NEAREST,
                "bilinear": cv2.INTER_LINEAR,
                "bicubic": cv2.INTER_CUBIC,
                "area": cv2.INTER_AREA,
                "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
            self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                 interpolation=self.interpolation)
            self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                        interpolation=cv2.INTER_NEAREST)
            self.center_crop = not random_crop
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = self.cropper

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        if self.size is not None:
            image = self.image_rescaler(image=image)["image"]
        segmentation = Image.open(example["segmentation_path_"])
        assert segmentation.mode == "L", segmentation.mode
        segmentation = np.array(segmentation).astype(np.uint8)
        if self.shift_segmentation:
            # used to support segmentations containing unlabeled==255 label
            segmentation = segmentation+1
        if self.size is not None:
            segmentation = self.segmentation_rescaler(image=segmentation)["image"]
        if self.size is not None:
            processed = self.preprocessor(image=image,
                                          mask=segmentation
                                          )
        else:
            processed = {"image": image,
                         "mask": segmentation
                         }
        example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
        segmentation = processed["mask"]
        onehot = np.eye(self.n_labels)[segmentation]
        example["segmentation"] = onehot
        return example


class Examples(SegmentationBase):
    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv="data/sflckr_examples.txt",
                         data_root="data/sflckr_images",
                         segmentation_root="data/sflckr_segmentations",
                         size=size, random_crop=random_crop, interpolation=interpolation)
taming/data/utils.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import urllib
|
5 |
+
import zipfile
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from taming.data.helper_types import Annotation
|
11 |
+
from torch._six import string_classes
|
12 |
+
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def unpack(path):
|
17 |
+
if path.endswith("tar.gz"):
|
18 |
+
with tarfile.open(path, "r:gz") as tar:
|
19 |
+
tar.extractall(path=os.path.split(path)[0])
|
20 |
+
elif path.endswith("tar"):
|
21 |
+
with tarfile.open(path, "r:") as tar:
|
22 |
+
tar.extractall(path=os.path.split(path)[0])
|
23 |
+
elif path.endswith("zip"):
|
24 |
+
with zipfile.ZipFile(path, "r") as f:
|
25 |
+
f.extractall(path=os.path.split(path)[0])
|
26 |
+
else:
|
27 |
+
raise NotImplementedError(
|
28 |
+
"Unknown file extension: {}".format(os.path.splitext(path)[1])
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def reporthook(bar):
|
33 |
+
"""tqdm progress bar for downloads."""
|
34 |
+
|
35 |
+
def hook(b=1, bsize=1, tsize=None):
|
36 |
+
if tsize is not None:
|
37 |
+
bar.total = tsize
|
38 |
+
bar.update(b * bsize - bar.n)
|
39 |
+
|
40 |
+
return hook
|
41 |
+
|
42 |
+
|
43 |
+
def get_root(name):
|
44 |
+
base = "data/"
|
45 |
+
root = os.path.join(base, name)
|
46 |
+
os.makedirs(root, exist_ok=True)
|
47 |
+
return root
|
48 |
+
|
49 |
+
|
50 |
+
def is_prepared(root):
|
51 |
+
return Path(root).joinpath(".ready").exists()
|
52 |
+
|
53 |
+
|
54 |
+
def mark_prepared(root):
|
55 |
+
Path(root).joinpath(".ready").touch()
|
56 |
+
|
57 |
+
|
58 |
+
def prompt_download(file_, source, target_dir, content_dir=None):
|
59 |
+
targetpath = os.path.join(target_dir, file_)
|
60 |
+
while not os.path.exists(targetpath):
|
61 |
+
if content_dir is not None and os.path.exists(
|
62 |
+
os.path.join(target_dir, content_dir)
|
63 |
+
):
|
64 |
+
break
|
65 |
+
print(
|
66 |
+
"Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
|
67 |
+
)
|
68 |
+
if content_dir is not None:
|
69 |
+
print(
|
70 |
+
"Or place its content into '{}'.".format(
|
71 |
+
os.path.join(target_dir, content_dir)
|
72 |
+
)
|
73 |
+
)
|
74 |
+
input("Press Enter when done...")
|
75 |
+
return targetpath
|
76 |
+
|
77 |
+
|
78 |
+
def download_url(file_, url, target_dir):
|
79 |
+
targetpath = os.path.join(target_dir, file_)
|
80 |
+
os.makedirs(target_dir, exist_ok=True)
|
81 |
+
with tqdm(
|
82 |
+
unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
|
83 |
+
) as bar:
|
84 |
+
urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
|
85 |
+
return targetpath
|
86 |
+
|
87 |
+
|
88 |
+
def download_urls(urls, target_dir):
|
89 |
+
paths = dict()
|
90 |
+
for fname, url in urls.items():
|
91 |
+
outpath = download_url(fname, url, target_dir)
|
92 |
+
paths[fname] = outpath
|
93 |
+
return paths
|
94 |
+
|
95 |
+
|
96 |
+
def quadratic_crop(x, bbox, alpha=1.0):
|
97 |
+
"""bbox is xmin, ymin, xmax, ymax"""
|
98 |
+
im_h, im_w = x.shape[:2]
|
99 |
+
bbox = np.array(bbox, dtype=np.float32)
|
100 |
+
bbox = np.clip(bbox, 0, max(im_h, im_w))
|
101 |
+
center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
|
102 |
+
w = bbox[2] - bbox[0]
|
103 |
+
h = bbox[3] - bbox[1]
|
104 |
+
l = int(alpha * max(w, h))
|
105 |
+
l = max(l, 2)
|
106 |
+
|
107 |
+
required_padding = -1 * min(
|
108 |
+
center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
|
109 |
+
)
|
110 |
+
required_padding = int(np.ceil(required_padding))
|
111 |
+
if required_padding > 0:
|
112 |
+
padding = [
|
113 |
+
[required_padding, required_padding],
|
114 |
+
[required_padding, required_padding],
|
115 |
+
]
|
116 |
+
padding += [[0, 0]] * (len(x.shape) - 2)
|
117 |
+
x = np.pad(x, padding, "reflect")
|
118 |
+
center = center[0] + required_padding, center[1] + required_padding
|
119 |
+
xmin = int(center[0] - l / 2)
|
120 |
+
ymin = int(center[1] - l / 2)
|
121 |
+
return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
|
122 |
+
|
123 |
+
|
124 |
+
def custom_collate(batch):
|
125 |
+
r"""source: pytorch 1.9.0, only one modification to original code """
|
126 |
+
|
127 |
+
elem = batch[0]
|
128 |
+
elem_type = type(elem)
|
129 |
+
if isinstance(elem, torch.Tensor):
|
130 |
+
out = None
|
131 |
+
if torch.utils.data.get_worker_info() is not None:
|
132 |
+
# If we're in a background process, concatenate directly into a
|
133 |
+
# shared memory tensor to avoid an extra copy
|
134 |
+
numel = sum([x.numel() for x in batch])
|
135 |
+
storage = elem.storage()._new_shared(numel)
|
136 |
+
out = elem.new(storage)
|
137 |
+
return torch.stack(batch, 0, out=out)
|
138 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
139 |
+
and elem_type.__name__ != 'string_':
|
140 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
141 |
+
# array of string classes and object
|
142 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
143 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
144 |
+
|
145 |
+
return custom_collate([torch.as_tensor(b) for b in batch])
|
146 |
+
elif elem.shape == (): # scalars
|
147 |
+
return torch.as_tensor(batch)
|
148 |
+
elif isinstance(elem, float):
|
149 |
+
return torch.tensor(batch, dtype=torch.float64)
|
150 |
+
elif isinstance(elem, int):
|
151 |
+
return torch.tensor(batch)
|
152 |
+
elif isinstance(elem, string_classes):
|
153 |
+
return batch
|
154 |
+
elif isinstance(elem, collections.abc.Mapping):
|
155 |
+
return {key: custom_collate([d[key] for d in batch]) for key in elem}
|
156 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
157 |
+
return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
|
158 |
+
if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
|
159 |
+
return batch # added
|
160 |
+
elif isinstance(elem, collections.abc.Sequence):
|
161 |
+
# check to make sure that the elements in batch have consistent size
|
162 |
+
it = iter(batch)
|
163 |
+
elem_size = len(next(it))
|
164 |
+
if not all(len(elem) == elem_size for elem in it):
|
165 |
+
raise RuntimeError('each element in list of batch should be of equal size')
|
166 |
+
transposed = zip(*batch)
|
167 |
+
return [custom_collate(samples) for samples in transposed]
|
168 |
+
|
169 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
taming/lr_scheduler.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
9 |
+
self.lr_warm_up_steps = warm_up_steps
|
10 |
+
self.lr_start = lr_start
|
11 |
+
self.lr_min = lr_min
|
12 |
+
self.lr_max = lr_max
|
13 |
+
self.lr_max_decay_steps = max_decay_steps
|
14 |
+
self.last_lr = 0.
|
15 |
+
self.verbosity_interval = verbosity_interval
|
16 |
+
|
17 |
+
def schedule(self, n):
|
18 |
+
if self.verbosity_interval > 0:
|
19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
20 |
+
if n < self.lr_warm_up_steps:
|
21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
22 |
+
self.last_lr = lr
|
23 |
+
return lr
|
24 |
+
else:
|
25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
26 |
+
t = min(t, 1.0)
|
27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
28 |
+
1 + np.cos(t * np.pi))
|
29 |
+
self.last_lr = lr
|
30 |
+
return lr
|
31 |
+
|
32 |
+
def __call__(self, n):
|
33 |
+
return self.schedule(n)
|
34 |
+
|
taming/models/cond_transformer.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
from main import instantiate_from_config
|
7 |
+
from taming.modules.util import SOSProvider
|
8 |
+
|
9 |
+
|
10 |
+
def disabled_train(self, mode=True):
|
11 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
12 |
+
does not change anymore."""
|
13 |
+
return self
|
14 |
+
|
15 |
+
|
16 |
+
class Net2NetTransformer(pl.LightningModule):
|
17 |
+
def __init__(self,
|
18 |
+
transformer_config,
|
19 |
+
first_stage_config,
|
20 |
+
cond_stage_config,
|
21 |
+
permuter_config=None,
|
22 |
+
ckpt_path=None,
|
23 |
+
ignore_keys=[],
|
24 |
+
first_stage_key="image",
|
25 |
+
cond_stage_key="depth",
|
26 |
+
downsample_cond_size=-1,
|
27 |
+
pkeep=1.0,
|
28 |
+
sos_token=0,
|
29 |
+
unconditional=False,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.be_unconditional = unconditional
|
33 |
+
self.sos_token = sos_token
|
34 |
+
self.first_stage_key = first_stage_key
|
35 |
+
self.cond_stage_key = cond_stage_key
|
36 |
+
self.init_first_stage_from_ckpt(first_stage_config)
|
37 |
+
self.init_cond_stage_from_ckpt(cond_stage_config)
|
38 |
+
if permuter_config is None:
|
39 |
+
permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
|
40 |
+
self.permuter = instantiate_from_config(config=permuter_config)
|
41 |
+
self.transformer = instantiate_from_config(config=transformer_config)
|
42 |
+
|
43 |
+
if ckpt_path is not None:
|
44 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
45 |
+
self.downsample_cond_size = downsample_cond_size
|
46 |
+
self.pkeep = pkeep
|
47 |
+
|
48 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
49 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
50 |
+
for k in sd.keys():
|
51 |
+
for ik in ignore_keys:
|
52 |
+
if k.startswith(ik):
|
53 |
+
self.print("Deleting key {} from state_dict.".format(k))
|
54 |
+
del sd[k]
|
55 |
+
self.load_state_dict(sd, strict=False)
|
56 |
+
print(f"Restored from {path}")
|
57 |
+
|
58 |
+
def init_first_stage_from_ckpt(self, config):
|
59 |
+
model = instantiate_from_config(config)
|
60 |
+
model = model.eval()
|
61 |
+
model.train = disabled_train
|
62 |
+
self.first_stage_model = model
|
63 |
+
|
64 |
+
def init_cond_stage_from_ckpt(self, config):
|
65 |
+
if config == "__is_first_stage__":
|
66 |
+
print("Using first stage also as cond stage.")
|
67 |
+
self.cond_stage_model = self.first_stage_model
|
68 |
+
elif config == "__is_unconditional__" or self.be_unconditional:
|
69 |
+
print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
|
70 |
+
f"Prepending {self.sos_token} as a sos token.")
|
71 |
+
self.be_unconditional = True
|
72 |
+
self.cond_stage_key = self.first_stage_key
|
73 |
+
self.cond_stage_model = SOSProvider(self.sos_token)
|
74 |
+
else:
|
75 |
+
model = instantiate_from_config(config)
|
76 |
+
model = model.eval()
|
77 |
+
model.train = disabled_train
|
78 |
+
self.cond_stage_model = model
|
79 |
+
|
80 |
+
def forward(self, x, c):
|
81 |
+
# one step to produce the logits
|
82 |
+
_, z_indices = self.encode_to_z(x)
|
83 |
+
_, c_indices = self.encode_to_c(c)
|
84 |
+
|
85 |
+
if self.training and self.pkeep < 1.0:
|
86 |
+
mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
|
87 |
+
device=z_indices.device))
|
88 |
+
mask = mask.round().to(dtype=torch.int64)
|
89 |
+
r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
|
90 |
+
a_indices = mask*z_indices+(1-mask)*r_indices
|
91 |
+
else:
|
92 |
+
a_indices = z_indices
|
93 |
+
|
94 |
+
cz_indices = torch.cat((c_indices, a_indices), dim=1)
|
95 |
+
|
96 |
+
# target includes all sequence elements (no need to handle first one
|
97 |
+
# differently because we are conditioning)
|
98 |
+
target = z_indices
|
99 |
+
# make the prediction
|
100 |
+
logits, _ = self.transformer(cz_indices[:, :-1])
|
101 |
+
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
|
102 |
+
logits = logits[:, c_indices.shape[1]-1:]
|
103 |
+
|
104 |
+
return logits, target
|
105 |
+
|
106 |
+
def top_k_logits(self, logits, k):
|
107 |
+
v, ix = torch.topk(logits, k)
|
108 |
+
out = logits.clone()
|
109 |
+
out[out < v[..., [-1]]] = -float('Inf')
|
110 |
+
return out
|
111 |
+
|
112 |
+
@torch.no_grad()
|
113 |
+
def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
|
114 |
+
callback=lambda k: None):
|
115 |
+
x = torch.cat((c,x),dim=1)
|
116 |
+
block_size = self.transformer.get_block_size()
|
117 |
+
assert not self.transformer.training
|
118 |
+
if self.pkeep <= 0.0:
|
119 |
+
# one pass suffices since input is pure noise anyway
|
120 |
+
assert len(x.shape)==2
|
121 |
+
noise_shape = (x.shape[0], steps-1)
|
122 |
+
#noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
|
123 |
+
noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
|
124 |
+
x = torch.cat((x,noise),dim=1)
|
125 |
+
logits, _ = self.transformer(x)
|
126 |
+
# take all logits for now and scale by temp
|
127 |
+
logits = logits / temperature
|
128 |
+
# optionally crop probabilities to only the top k options
|
129 |
+
if top_k is not None:
|
130 |
+
logits = self.top_k_logits(logits, top_k)
|
131 |
+
# apply softmax to convert to probabilities
|
132 |
+
probs = F.softmax(logits, dim=-1)
|
133 |
+
# sample from the distribution or take the most likely
|
134 |
+
if sample:
|
135 |
+
shape = probs.shape
|
136 |
+
probs = probs.reshape(shape[0]*shape[1],shape[2])
|
137 |
+
ix = torch.multinomial(probs, num_samples=1)
|
138 |
+
probs = probs.reshape(shape[0],shape[1],shape[2])
|
139 |
+
ix = ix.reshape(shape[0],shape[1])
|
140 |
+
else:
|
141 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
142 |
+
# cut off conditioning
|
143 |
+
x = ix[:, c.shape[1]-1:]
|
144 |
+
else:
|
145 |
+
for k in range(steps):
|
146 |
+
callback(k)
|
147 |
+
assert x.size(1) <= block_size # make sure model can see conditioning
|
148 |
+
x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
|
149 |
+
logits, _ = self.transformer(x_cond)
|
150 |
+
# pluck the logits at the final step and scale by temperature
|
151 |
+
logits = logits[:, -1, :] / temperature
|
152 |
+
# optionally crop probabilities to only the top k options
|
153 |
+
if top_k is not None:
|
154 |
+
logits = self.top_k_logits(logits, top_k)
|
155 |
+
# apply softmax to convert to probabilities
|
156 |
+
probs = F.softmax(logits, dim=-1)
|
157 |
+
# sample from the distribution or take the most likely
|
158 |
+
if sample:
|
159 |
+
ix = torch.multinomial(probs, num_samples=1)
|
160 |
+
else:
|
161 |
+
_, ix = torch.topk(probs, k=1, dim=-1)
|
162 |
+
# append to the sequence and continue
|
163 |
+
x = torch.cat((x, ix), dim=1)
|
164 |
+
# cut off conditioning
|
165 |
+
x = x[:, c.shape[1]:]
|
166 |
+
return x
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def encode_to_z(self, x):
|
170 |
+
quant_z, _, info = self.first_stage_model.encode(x)
|
171 |
+
indices = info[2].view(quant_z.shape[0], -1)
|
172 |
+
indices = self.permuter(indices)
|
173 |
+
return quant_z, indices
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
def encode_to_c(self, c):
|
177 |
+
if self.downsample_cond_size > -1:
|
178 |
+
c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
|
179 |
+
quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
|
180 |
+
if len(indices.shape) > 2:
|
181 |
+
indices = indices.view(c.shape[0], -1)
|
182 |
+
return quant_c, indices
|
183 |
+
|
184 |
+
@torch.no_grad()
|
185 |
+
def decode_to_img(self, index, zshape):
|
186 |
+
index = self.permuter(index, reverse=True)
|
187 |
+
bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
|
188 |
+
quant_z = self.first_stage_model.quantize.get_codebook_entry(
|
189 |
+
index.reshape(-1), shape=bhwc)
|
190 |
+
x = self.first_stage_model.decode(quant_z)
|
191 |
+
return x
|
192 |
+
|
193 |
+
@torch.no_grad()
|
194 |
+
def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
|
195 |
+
log = dict()
|
196 |
+
|
197 |
+
N = 4
|
198 |
+
if lr_interface:
|
199 |
+
x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
|
200 |
+
else:
|
201 |
+
x, c = self.get_xc(batch, N)
|
202 |
+
x = x.to(device=self.device)
|
203 |
+
c = c.to(device=self.device)
|
204 |
+
|
205 |
+
quant_z, z_indices = self.encode_to_z(x)
|
206 |
+
quant_c, c_indices = self.encode_to_c(c)
|
207 |
+
|
208 |
+
# create a "half"" sample
|
209 |
+
z_start_indices = z_indices[:,:z_indices.shape[1]//2]
|
210 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
211 |
+
steps=z_indices.shape[1]-z_start_indices.shape[1],
|
212 |
+
temperature=temperature if temperature is not None else 1.0,
|
213 |
+
sample=True,
|
214 |
+
top_k=top_k if top_k is not None else 100,
|
215 |
+
callback=callback if callback is not None else lambda k: None)
|
216 |
+
x_sample = self.decode_to_img(index_sample, quant_z.shape)
|
217 |
+
|
218 |
+
# sample
|
219 |
+
z_start_indices = z_indices[:, :0]
|
220 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
221 |
+
steps=z_indices.shape[1],
|
222 |
+
temperature=temperature if temperature is not None else 1.0,
|
223 |
+
sample=True,
|
224 |
+
top_k=top_k if top_k is not None else 100,
|
225 |
+
callback=callback if callback is not None else lambda k: None)
|
226 |
+
x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
|
227 |
+
|
228 |
+
# det sample
|
229 |
+
z_start_indices = z_indices[:, :0]
|
230 |
+
index_sample = self.sample(z_start_indices, c_indices,
|
231 |
+
steps=z_indices.shape[1],
|
232 |
+
sample=False,
|
233 |
+
callback=callback if callback is not None else lambda k: None)
|
234 |
+
x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
|
235 |
+
|
236 |
+
# reconstruction
|
237 |
+
x_rec = self.decode_to_img(z_indices, quant_z.shape)
|
238 |
+
|
239 |
+
log["inputs"] = x
|
240 |
+
log["reconstructions"] = x_rec
|
241 |
+
|
242 |
+
if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
|
243 |
+
figure_size = (x_rec.shape[2], x_rec.shape[3])
|
244 |
+
dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
|
245 |
+
label_for_category_no = dataset.get_textual_label_for_category_no
|
246 |
+
plotter = dataset.conditional_builders[self.cond_stage_key].plot
|
247 |
+
log["conditioning"] = torch.zeros_like(log["reconstructions"])
|
248 |
+
for i in range(quant_c.shape[0]):
|
249 |
+
log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
|
250 |
+
log["conditioning_rec"] = log["conditioning"]
|
251 |
+
elif self.cond_stage_key != "image":
|
252 |
+
cond_rec = self.cond_stage_model.decode(quant_c)
|
253 |
+
if self.cond_stage_key == "segmentation":
|
254 |
+
# get image from segmentation mask
|
255 |
+
num_classes = cond_rec.shape[1]
|
256 |
+
|
257 |
+
c = torch.argmax(c, dim=1, keepdim=True)
|
258 |
+
c = F.one_hot(c, num_classes=num_classes)
|
259 |
+
c = c.squeeze(1).permute(0, 3, 1, 2).float()
|
260 |
+
c = self.cond_stage_model.to_rgb(c)
|
261 |
+
|
262 |
+
cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
|
263 |
+
cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
|
264 |
+
cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
|
265 |
+
cond_rec = self.cond_stage_model.to_rgb(cond_rec)
|
266 |
+
log["conditioning_rec"] = cond_rec
|
267 |
+
log["conditioning"] = c
|
268 |
+
|
269 |
+
log["samples_half"] = x_sample
|
270 |
+
log["samples_nopix"] = x_sample_nopix
|
271 |
+
log["samples_det"] = x_sample_det
|
272 |
+
return log
|
273 |
+
|
274 |
+
def get_input(self, key, batch):
|
275 |
+
x = batch[key]
|
276 |
+
if len(x.shape) == 3:
|
277 |
+
x = x[..., None]
|
278 |
+
if len(x.shape) == 4:
|
279 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
280 |
+
if x.dtype == torch.double:
|
281 |
+
x = x.float()
|
282 |
+
return x
|
283 |
+
|
284 |
+
def get_xc(self, batch, N=None):
|
285 |
+
x = self.get_input(self.first_stage_key, batch)
|
286 |
+
c = self.get_input(self.cond_stage_key, batch)
|
287 |
+
if N is not None:
|
288 |
+
x = x[:N]
|
289 |
+
c = c[:N]
|
290 |
+
return x, c
|
291 |
+
|
292 |
+
def shared_step(self, batch, batch_idx):
|
293 |
+
x, c = self.get_xc(batch)
|
294 |
+
logits, target = self(x, c)
|
295 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
|
296 |
+
return loss
|
297 |
+
|
298 |
+
def training_step(self, batch, batch_idx):
|
299 |
+
loss = self.shared_step(batch, batch_idx)
|
300 |
+
self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
301 |
+
return loss
|
302 |
+
|
303 |
+
def validation_step(self, batch, batch_idx):
|
304 |
+
loss = self.shared_step(batch, batch_idx)
|
305 |
+
self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
306 |
+
return loss
|
307 |
+
|
308 |
+
def configure_optimizers(self):
|
309 |
+
"""
|
310 |
+
Following minGPT:
|
311 |
+
This long function is unfortunately doing something very simple and is being very defensive:
|
312 |
+
We are separating out all parameters of the model into two buckets: those that will experience
|
313 |
+
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
|
314 |
+
We are then returning the PyTorch optimizer object.
|
315 |
+
"""
|
316 |
+
# separate out all parameters to those that will and won't experience regularizing weight decay
|
317 |
+
decay = set()
|
318 |
+
no_decay = set()
|
319 |
+
whitelist_weight_modules = (torch.nn.Linear, )
|
320 |
+
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
|
321 |
+
for mn, m in self.transformer.named_modules():
|
322 |
+
for pn, p in m.named_parameters():
|
323 |
+
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
|
324 |
+
|
325 |
+
if pn.endswith('bias'):
|
326 |
+
# all biases will not be decayed
|
327 |
+
no_decay.add(fpn)
|
328 |
+
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
|
329 |
+
# weights of whitelist modules will be weight decayed
|
330 |
+
decay.add(fpn)
|
331 |
+
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
|
332 |
+
# weights of blacklist modules will NOT be weight decayed
|
333 |
+
no_decay.add(fpn)
|
334 |
+
|
335 |
+
# special case the position embedding parameter in the root GPT module as not decayed
|
336 |
+
no_decay.add('pos_emb')
|
337 |
+
|
338 |
+
# validate that we considered every parameter
|
339 |
+
param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
|
340 |
+
inter_params = decay & no_decay
|
341 |
+
union_params = decay | no_decay
|
342 |
+
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
|
343 |
+
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
|
344 |
+
% (str(param_dict.keys() - union_params), )
|
345 |
+
|
346 |
+
# create the pytorch optimizer object
|
347 |
+
optim_groups = [
|
348 |
+
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
|
349 |
+
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
|
350 |
+
]
|
351 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
|
352 |
+
return optimizer
|
taming/models/dummy_cond_stage.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
|
4 |
+
class DummyCondStage:
|
5 |
+
def __init__(self, conditional_key):
|
6 |
+
self.conditional_key = conditional_key
|
7 |
+
self.train = None
|
8 |
+
|
9 |
+
def eval(self):
|
10 |
+
return self
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def encode(c: Tensor):
|
14 |
+
return c, None, (None, None, c)
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def decode(c: Tensor):
|
18 |
+
return c
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def to_rgb(c: Tensor):
|
22 |
+
return c
|
taming/models/vqgan.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from main import instantiate_from_config
|
6 |
+
|
7 |
+
from taming.modules.diffusionmodules.model import Encoder, Decoder
|
8 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
9 |
+
from taming.modules.vqvae.quantize import GumbelQuantize
|
10 |
+
from taming.modules.vqvae.quantize import EMAVectorQuantizer
|
11 |
+
|
12 |
+
class VQModel(pl.LightningModule):
|
13 |
+
def __init__(self,
|
14 |
+
ddconfig,
|
15 |
+
lossconfig,
|
16 |
+
n_embed,
|
17 |
+
embed_dim,
|
18 |
+
ckpt_path=None,
|
19 |
+
ignore_keys=[],
|
20 |
+
image_key="image",
|
21 |
+
colorize_nlabels=None,
|
22 |
+
monitor=None,
|
23 |
+
remap=None,
|
24 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
self.image_key = image_key
|
28 |
+
self.encoder = Encoder(**ddconfig)
|
29 |
+
self.decoder = Decoder(**ddconfig)
|
30 |
+
self.loss = instantiate_from_config(lossconfig)
|
31 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
32 |
+
remap=remap, sane_index_shape=sane_index_shape)
|
33 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
34 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
35 |
+
if ckpt_path is not None:
|
36 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
37 |
+
self.image_key = image_key
|
38 |
+
if colorize_nlabels is not None:
|
39 |
+
assert type(colorize_nlabels)==int
|
40 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
41 |
+
if monitor is not None:
|
42 |
+
self.monitor = monitor
|
43 |
+
|
44 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
45 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
46 |
+
keys = list(sd.keys())
|
47 |
+
for k in keys:
|
48 |
+
for ik in ignore_keys:
|
49 |
+
if k.startswith(ik):
|
50 |
+
print("Deleting key {} from state_dict.".format(k))
|
51 |
+
del sd[k]
|
52 |
+
self.load_state_dict(sd, strict=False)
|
53 |
+
print(f"Restored from {path}")
|
54 |
+
|
55 |
+
def encode(self, x):
|
56 |
+
h = self.encoder(x)
|
57 |
+
h = self.quant_conv(h)
|
58 |
+
quant, emb_loss, info = self.quantize(h)
|
59 |
+
return quant, emb_loss, info
|
60 |
+
|
61 |
+
def decode(self, quant):
|
62 |
+
quant = self.post_quant_conv(quant)
|
63 |
+
dec = self.decoder(quant)
|
64 |
+
return dec
|
65 |
+
|
66 |
+
def decode_code(self, code_b):
|
67 |
+
quant_b = self.quantize.embed_code(code_b)
|
68 |
+
dec = self.decode(quant_b)
|
69 |
+
return dec
|
70 |
+
|
71 |
+
def forward(self, input):
|
72 |
+
quant, diff, _ = self.encode(input)
|
73 |
+
dec = self.decode(quant)
|
74 |
+
return dec, diff
|
75 |
+
|
76 |
+
def get_input(self, batch, k):
|
77 |
+
x = batch[k]
|
78 |
+
if len(x.shape) == 3:
|
79 |
+
x = x[..., None]
|
80 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
81 |
+
return x.float()
|
82 |
+
|
83 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
84 |
+
x = self.get_input(batch, self.image_key)
|
85 |
+
xrec, qloss = self(x)
|
86 |
+
|
87 |
+
if optimizer_idx == 0:
|
88 |
+
# autoencode
|
89 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
90 |
+
last_layer=self.get_last_layer(), split="train")
|
91 |
+
|
92 |
+
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
93 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
94 |
+
return aeloss
|
95 |
+
|
96 |
+
if optimizer_idx == 1:
|
97 |
+
# discriminator
|
98 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
99 |
+
last_layer=self.get_last_layer(), split="train")
|
100 |
+
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
101 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
102 |
+
return discloss
|
103 |
+
|
104 |
+
def validation_step(self, batch, batch_idx):
|
105 |
+
x = self.get_input(batch, self.image_key)
|
106 |
+
xrec, qloss = self(x)
|
107 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
108 |
+
last_layer=self.get_last_layer(), split="val")
|
109 |
+
|
110 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
111 |
+
last_layer=self.get_last_layer(), split="val")
|
112 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
113 |
+
self.log("val/rec_loss", rec_loss,
|
114 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
115 |
+
self.log("val/aeloss", aeloss,
|
116 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
117 |
+
self.log_dict(log_dict_ae)
|
118 |
+
self.log_dict(log_dict_disc)
|
119 |
+
return self.log_dict
|
120 |
+
|
121 |
+
def configure_optimizers(self):
|
122 |
+
lr = self.learning_rate
|
123 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
124 |
+
list(self.decoder.parameters())+
|
125 |
+
list(self.quantize.parameters())+
|
126 |
+
list(self.quant_conv.parameters())+
|
127 |
+
list(self.post_quant_conv.parameters()),
|
128 |
+
lr=lr, betas=(0.5, 0.9))
|
129 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
130 |
+
lr=lr, betas=(0.5, 0.9))
|
131 |
+
return [opt_ae, opt_disc], []
|
132 |
+
|
133 |
+
def get_last_layer(self):
|
134 |
+
return self.decoder.conv_out.weight
|
135 |
+
|
136 |
+
def log_images(self, batch, **kwargs):
|
137 |
+
log = dict()
|
138 |
+
x = self.get_input(batch, self.image_key)
|
139 |
+
x = x.to(self.device)
|
140 |
+
xrec, _ = self(x)
|
141 |
+
if x.shape[1] > 3:
|
142 |
+
# colorize with random projection
|
143 |
+
assert xrec.shape[1] > 3
|
144 |
+
x = self.to_rgb(x)
|
145 |
+
xrec = self.to_rgb(xrec)
|
146 |
+
log["inputs"] = x
|
147 |
+
log["reconstructions"] = xrec
|
148 |
+
return log
|
149 |
+
|
150 |
+
def to_rgb(self, x):
|
151 |
+
assert self.image_key == "segmentation"
|
152 |
+
if not hasattr(self, "colorize"):
|
153 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
154 |
+
x = F.conv2d(x, weight=self.colorize)
|
155 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class VQSegmentationModel(VQModel):
|
160 |
+
def __init__(self, n_labels, *args, **kwargs):
|
161 |
+
super().__init__(*args, **kwargs)
|
162 |
+
self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
|
163 |
+
|
164 |
+
def configure_optimizers(self):
|
165 |
+
lr = self.learning_rate
|
166 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
167 |
+
list(self.decoder.parameters())+
|
168 |
+
list(self.quantize.parameters())+
|
169 |
+
list(self.quant_conv.parameters())+
|
170 |
+
list(self.post_quant_conv.parameters()),
|
171 |
+
lr=lr, betas=(0.5, 0.9))
|
172 |
+
return opt_ae
|
173 |
+
|
174 |
+
def training_step(self, batch, batch_idx):
|
175 |
+
x = self.get_input(batch, self.image_key)
|
176 |
+
xrec, qloss = self(x)
|
177 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
178 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
179 |
+
return aeloss
|
180 |
+
|
181 |
+
def validation_step(self, batch, batch_idx):
|
182 |
+
x = self.get_input(batch, self.image_key)
|
183 |
+
xrec, qloss = self(x)
|
184 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
185 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
186 |
+
total_loss = log_dict_ae["val/total_loss"]
|
187 |
+
self.log("val/total_loss", total_loss,
|
188 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
|
189 |
+
return aeloss
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def log_images(self, batch, **kwargs):
|
193 |
+
log = dict()
|
194 |
+
x = self.get_input(batch, self.image_key)
|
195 |
+
x = x.to(self.device)
|
196 |
+
xrec, _ = self(x)
|
197 |
+
if x.shape[1] > 3:
|
198 |
+
# colorize with random projection
|
199 |
+
assert xrec.shape[1] > 3
|
200 |
+
# convert logits to indices
|
201 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
202 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
203 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
204 |
+
x = self.to_rgb(x)
|
205 |
+
xrec = self.to_rgb(xrec)
|
206 |
+
log["inputs"] = x
|
207 |
+
log["reconstructions"] = xrec
|
208 |
+
return log
|
209 |
+
|
210 |
+
|
211 |
+
class VQNoDiscModel(VQModel):
|
212 |
+
def __init__(self,
|
213 |
+
ddconfig,
|
214 |
+
lossconfig,
|
215 |
+
n_embed,
|
216 |
+
embed_dim,
|
217 |
+
ckpt_path=None,
|
218 |
+
ignore_keys=[],
|
219 |
+
image_key="image",
|
220 |
+
colorize_nlabels=None
|
221 |
+
):
|
222 |
+
super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
|
223 |
+
ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
|
224 |
+
colorize_nlabels=colorize_nlabels)
|
225 |
+
|
226 |
+
def training_step(self, batch, batch_idx):
|
227 |
+
x = self.get_input(batch, self.image_key)
|
228 |
+
xrec, qloss = self(x)
|
229 |
+
# autoencode
|
230 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
|
231 |
+
output = pl.TrainResult(minimize=aeloss)
|
232 |
+
output.log("train/aeloss", aeloss,
|
233 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
234 |
+
output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
235 |
+
return output
|
236 |
+
|
237 |
+
def validation_step(self, batch, batch_idx):
|
238 |
+
x = self.get_input(batch, self.image_key)
|
239 |
+
xrec, qloss = self(x)
|
240 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
|
241 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
242 |
+
output = pl.EvalResult(checkpoint_on=rec_loss)
|
243 |
+
output.log("val/rec_loss", rec_loss,
|
244 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
245 |
+
output.log("val/aeloss", aeloss,
|
246 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
247 |
+
output.log_dict(log_dict_ae)
|
248 |
+
|
249 |
+
return output
|
250 |
+
|
251 |
+
def configure_optimizers(self):
|
252 |
+
optimizer = torch.optim.Adam(list(self.encoder.parameters())+
|
253 |
+
list(self.decoder.parameters())+
|
254 |
+
list(self.quantize.parameters())+
|
255 |
+
list(self.quant_conv.parameters())+
|
256 |
+
list(self.post_quant_conv.parameters()),
|
257 |
+
lr=self.learning_rate, betas=(0.5, 0.9))
|
258 |
+
return optimizer
|
259 |
+
|
260 |
+
|
261 |
+
class GumbelVQ(VQModel):
|
262 |
+
def __init__(self,
|
263 |
+
ddconfig,
|
264 |
+
lossconfig,
|
265 |
+
n_embed,
|
266 |
+
embed_dim,
|
267 |
+
temperature_scheduler_config,
|
268 |
+
ckpt_path=None,
|
269 |
+
ignore_keys=[],
|
270 |
+
image_key="image",
|
271 |
+
colorize_nlabels=None,
|
272 |
+
monitor=None,
|
273 |
+
kl_weight=1e-8,
|
274 |
+
remap=None,
|
275 |
+
):
|
276 |
+
|
277 |
+
z_channels = ddconfig["z_channels"]
|
278 |
+
super().__init__(ddconfig,
|
279 |
+
lossconfig,
|
280 |
+
n_embed,
|
281 |
+
embed_dim,
|
282 |
+
ckpt_path=None,
|
283 |
+
ignore_keys=ignore_keys,
|
284 |
+
image_key=image_key,
|
285 |
+
colorize_nlabels=colorize_nlabels,
|
286 |
+
monitor=monitor,
|
287 |
+
)
|
288 |
+
|
289 |
+
self.loss.n_classes = n_embed
|
290 |
+
self.vocab_size = n_embed
|
291 |
+
|
292 |
+
self.quantize = GumbelQuantize(z_channels, embed_dim,
|
293 |
+
n_embed=n_embed,
|
294 |
+
kl_weight=kl_weight, temp_init=1.0,
|
295 |
+
remap=remap)
|
296 |
+
|
297 |
+
self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
|
298 |
+
|
299 |
+
if ckpt_path is not None:
|
300 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
301 |
+
|
302 |
+
def temperature_scheduling(self):
|
303 |
+
self.quantize.temperature = self.temperature_scheduler(self.global_step)
|
304 |
+
|
305 |
+
def encode_to_prequant(self, x):
|
306 |
+
h = self.encoder(x)
|
307 |
+
h = self.quant_conv(h)
|
308 |
+
return h
|
309 |
+
|
310 |
+
def decode_code(self, code_b):
|
311 |
+
raise NotImplementedError
|
312 |
+
|
313 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
314 |
+
self.temperature_scheduling()
|
315 |
+
x = self.get_input(batch, self.image_key)
|
316 |
+
xrec, qloss = self(x)
|
317 |
+
|
318 |
+
if optimizer_idx == 0:
|
319 |
+
# autoencode
|
320 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
321 |
+
last_layer=self.get_last_layer(), split="train")
|
322 |
+
|
323 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
324 |
+
self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
325 |
+
return aeloss
|
326 |
+
|
327 |
+
if optimizer_idx == 1:
|
328 |
+
# discriminator
|
329 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
330 |
+
last_layer=self.get_last_layer(), split="train")
|
331 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
332 |
+
return discloss
|
333 |
+
|
334 |
+
def validation_step(self, batch, batch_idx):
|
335 |
+
x = self.get_input(batch, self.image_key)
|
336 |
+
xrec, qloss = self(x, return_pred_indices=True)
|
337 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
|
338 |
+
last_layer=self.get_last_layer(), split="val")
|
339 |
+
|
340 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
|
341 |
+
last_layer=self.get_last_layer(), split="val")
|
342 |
+
rec_loss = log_dict_ae["val/rec_loss"]
|
343 |
+
self.log("val/rec_loss", rec_loss,
|
344 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
345 |
+
self.log("val/aeloss", aeloss,
|
346 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
347 |
+
self.log_dict(log_dict_ae)
|
348 |
+
self.log_dict(log_dict_disc)
|
349 |
+
return self.log_dict
|
350 |
+
|
351 |
+
def log_images(self, batch, **kwargs):
|
352 |
+
log = dict()
|
353 |
+
x = self.get_input(batch, self.image_key)
|
354 |
+
x = x.to(self.device)
|
355 |
+
# encode
|
356 |
+
h = self.encoder(x)
|
357 |
+
h = self.quant_conv(h)
|
358 |
+
quant, _, _ = self.quantize(h)
|
359 |
+
# decode
|
360 |
+
x_rec = self.decode(quant)
|
361 |
+
log["inputs"] = x
|
362 |
+
log["reconstructions"] = x_rec
|
363 |
+
return log
|
364 |
+
|
365 |
+
|
366 |
+
class EMAVQ(VQModel):
|
367 |
+
def __init__(self,
|
368 |
+
ddconfig,
|
369 |
+
lossconfig,
|
370 |
+
n_embed,
|
371 |
+
embed_dim,
|
372 |
+
ckpt_path=None,
|
373 |
+
ignore_keys=[],
|
374 |
+
image_key="image",
|
375 |
+
colorize_nlabels=None,
|
376 |
+
monitor=None,
|
377 |
+
remap=None,
|
378 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
379 |
+
):
|
380 |
+
super().__init__(ddconfig,
|
381 |
+
lossconfig,
|
382 |
+
n_embed,
|
383 |
+
embed_dim,
|
384 |
+
ckpt_path=None,
|
385 |
+
ignore_keys=ignore_keys,
|
386 |
+
image_key=image_key,
|
387 |
+
colorize_nlabels=colorize_nlabels,
|
388 |
+
monitor=monitor,
|
389 |
+
)
|
390 |
+
self.quantize = EMAVectorQuantizer(n_embed=n_embed,
|
391 |
+
embedding_dim=embed_dim,
|
392 |
+
beta=0.25,
|
393 |
+
remap=remap)
|
394 |
+
def configure_optimizers(self):
|
395 |
+
lr = self.learning_rate
|
396 |
+
#Remove self.quantize from parameter list since it is updated via EMA
|
397 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
398 |
+
list(self.decoder.parameters())+
|
399 |
+
list(self.quant_conv.parameters())+
|
400 |
+
list(self.post_quant_conv.parameters()),
|
401 |
+
lr=lr, betas=(0.5, 0.9))
|
402 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
403 |
+
lr=lr, betas=(0.5, 0.9))
|
404 |
+
return [opt_ae, opt_disc], []
|
taming/modules/__pycache__/util.cpython-38.pyc
ADDED
Binary file (4.28 kB). View file
|
|
taming/modules/autoencoder/lpips/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|
taming/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
9 |
+
"""
|
10 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
11 |
+
From Fairseq.
|
12 |
+
Build sinusoidal embeddings.
|
13 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
14 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
15 |
+
"""
|
16 |
+
assert len(timesteps.shape) == 1
|
17 |
+
|
18 |
+
half_dim = embedding_dim // 2
|
19 |
+
emb = math.log(10000) / (half_dim - 1)
|
20 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
21 |
+
emb = emb.to(device=timesteps.device)
|
22 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
23 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
24 |
+
if embedding_dim % 2 == 1: # zero pad
|
25 |
+
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
26 |
+
return emb
|
27 |
+
|
28 |
+
|
29 |
+
def nonlinearity(x):
|
30 |
+
# swish
|
31 |
+
return x*torch.sigmoid(x)
|
32 |
+
|
33 |
+
|
34 |
+
def Normalize(in_channels):
|
35 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
36 |
+
|
37 |
+
|
38 |
+
class Upsample(nn.Module):
|
39 |
+
def __init__(self, in_channels, with_conv):
|
40 |
+
super().__init__()
|
41 |
+
self.with_conv = with_conv
|
42 |
+
if self.with_conv:
|
43 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
44 |
+
in_channels,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
51 |
+
if self.with_conv:
|
52 |
+
x = self.conv(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Downsample(nn.Module):
|
57 |
+
def __init__(self, in_channels, with_conv):
|
58 |
+
super().__init__()
|
59 |
+
self.with_conv = with_conv
|
60 |
+
if self.with_conv:
|
61 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
62 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
63 |
+
in_channels,
|
64 |
+
kernel_size=3,
|
65 |
+
stride=2,
|
66 |
+
padding=0)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.with_conv:
|
70 |
+
pad = (0,1,0,1)
|
71 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
72 |
+
x = self.conv(x)
|
73 |
+
else:
|
74 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class ResnetBlock(nn.Module):
|
79 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
80 |
+
dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None):
        #assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2*z_channels if double_z else z_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VUNet(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, c_channels,
                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=2*block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, z):
        #assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        z = self.z_in(z)
        h = torch.cat((h, z), dim=1)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2*in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                             temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
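For orientation, a minimal usage sketch (not part of the upload) of the Encoder/Decoder pair defined above. The hyperparameter values are illustrative stand-ins for a typical f16-style configuration, not values read from any config file in this repository.

import torch

# Assumes the Encoder/Decoder classes (and their helpers) from the file above are in scope.
enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=256,
              z_channels=256, double_z=False)
dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=256,
              z_channels=256)

x = torch.randn(1, 3, 256, 256)
z = enc(x)         # (1, 256, 16, 16): resolution is halved len(ch_mult)-1 times
x_rec = dec(z)     # (1, 3, 256, 256)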
taming/modules/discriminator/__pycache__/model.cpython-38.pyc
ADDED
Binary file (2.34 kB)
taming/modules/discriminator/model.py
ADDED
@@ -0,0 +1,67 @@
import functools
import torch.nn as nn


from taming.modules.util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
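A short illustrative sketch (not part of the upload) of how this PatchGAN discriminator is instantiated and what it returns; the input size is an arbitrary choice for the example.

import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False).apply(weights_init)
logits = disc(torch.randn(4, 3, 256, 256))
# logits has shape (4, 1, 30, 30): one real/fake score per image patch rather than per image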
taming/modules/losses/__init__.py
ADDED
@@ -0,0 +1,2 @@
from taming.modules.losses.vqperceptual import DummyLoss
taming/modules/losses/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (231 Bytes)

taming/modules/losses/__pycache__/lpips.cpython-38.pyc
ADDED
Binary file (5.3 kB)

taming/modules/losses/__pycache__/vqperceptual.cpython-38.pyc
ADDED
Binary file (6.94 kB)
taming/modules/losses/lpips.py
ADDED
@@ -0,0 +1,123 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple

from taming.util import get_ckpt_path


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if (use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
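A minimal sketch (not from the upload) of calling the LPIPS metric above. It assumes the pretrained VGG and linear-layer weights can be loaded via the constructor (which calls load_from_pretrained), and that inputs are scaled roughly to [-1, 1].

import torch

lpips = LPIPS().eval()                      # loads pretrained weights inside __init__
a = torch.rand(2, 3, 256, 256) * 2 - 1
b = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(a, b)                         # (2, 1, 1, 1): per-image perceptual distance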
taming/modules/losses/segmentation.py
ADDED
@@ -0,0 +1,22 @@
import torch.nn as nn
import torch.nn.functional as F


class BCELoss(nn.Module):
    def forward(self, prediction, target):
        loss = F.binary_cross_entropy_with_logits(prediction, target)
        return loss, {}


class BCELossWithQuant(nn.Module):
    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight*qloss
        return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
                      "{}/bce_loss".format(split): bce_loss.detach().mean(),
                      "{}/quant_loss".format(split): qloss.detach().mean()
                      }
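An illustrative call of BCELossWithQuant (not from the upload); the tensors and the channel count are made-up placeholders for a segmentation model's outputs.

import torch

criterion = BCELossWithQuant(codebook_weight=1.0)
qloss = torch.tensor(0.1)                              # codebook/quantization loss from the VQ model
logits = torch.randn(2, 150, 64, 64)                   # raw per-class segmentation logits
target = torch.randint(0, 2, (2, 150, 64, 64)).float() # one-hot-style targets
loss, log = criterion(qloss, target, logits, split="train")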
taming/modules/losses/vqperceptual.py
ADDED
@@ -0,0 +1,241 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from taming.modules.losses.lpips import LPIPS
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init


class DummyLoss(nn.Module):
    def __init__(self):
        super().__init__()


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train"):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
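A minimal sketch (not from the upload) of the two-pass calling convention used by VQLPIPSWithDiscriminator: optimizer_idx 0 produces the autoencoder loss, optimizer_idx 1 the discriminator loss. The tensors below are placeholders for a VQ model's inputs, reconstructions, and codebook loss, and constructing the loss pulls in the pretrained LPIPS weights.

import torch

loss_fn = VQLPIPSWithDiscriminator(disc_start=10000)          # downloads/loads LPIPS weights
x = torch.rand(2, 3, 64, 64) * 2 - 1                          # stand-in for real images
x_rec = (torch.rand(2, 3, 64, 64) * 2 - 1).requires_grad_(True)  # stand-in for reconstructions
qloss = torch.tensor(0.25)                                    # stand-in codebook loss

# generator / autoencoder pass
aeloss, log_ae = loss_fn(qloss, x, x_rec, optimizer_idx=0, global_step=0,
                         last_layer=x_rec, split="train")
# discriminator pass
discloss, log_d = loss_fn(qloss, x, x_rec, optimizer_idx=1, global_step=0, split="train")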
taming/modules/misc/coord.py
ADDED
@@ -0,0 +1,31 @@
import torch

class CoordStage(object):
    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        return self

    def encode(self, c):
        """fake vqmodel interface"""
        assert 0.0 <= c.min() and c.max() <= 1.0
        b, ch, h, w = c.shape
        assert ch == 1

        c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
                                            mode="area")
        c = c.clamp(0.0, 1.0)
        c = self.n_embed*c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        c = c/self.n_embed
        c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
                                            mode="nearest")
        return c
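For illustration (not from the upload), CoordStage exposes the same encode/decode interface as a VQ model but simply downsamples a [0, 1] coordinate map and rounds it into n_embed integer bins; the sizes below are arbitrary.

import torch

stage = CoordStage(n_embed=1024, down_factor=16)
c = torch.rand(1, 1, 256, 256)                  # coordinate map in [0, 1]
c_quant, _, (_, _, c_ind) = stage.encode(c)     # c_quant: (1, 1, 16, 16), c_ind: integer codes
c_up = stage.decode(c_quant)                    # back to (1, 1, 256, 256) via nearest upsampling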
taming/modules/transformer/mingpt.py
ADDED
@@ -0,0 +1,415 @@
"""
taken from: https://github.com/karpathy/minGPT/
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import top_k_top_p_filtering

logger = logging.getLogger(__name__)


class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)


class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        mask = torch.tril(torch.ones(config.block_size,
                                     config.block_size))
        if hasattr(config, "n_unmasked"):
            mask[:config.n_unmasked, :config.n_unmasked] = 1
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        present = torch.stack((k, v))
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if layer_past is None:
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, present  # TODO: check that this does not break anything


class Block(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),  # nice
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):
        # TODO: check that training still works
        if return_present: assert not self.training
        # layer past: tuple of length two with B, nh, T, hs
        attn, present = self.attn(self.ln1(x), layer_past=layer_past)

        x = x + attn
        x = x + self.mlp(self.ln2(x))
        if layer_past is not None or return_present:
            return x, present
        return x


class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """
    def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
        # inference only
        assert not self.training
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        if past is not None:
            assert past_length is not None
            past = torch.cat(past, dim=-2)  # n_layer, 2, b, nh, len_past, dim_head
            past_shape = list(past.shape)
            expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
            assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
            position_embeddings = self.pos_emb[:, past_length, :]  # each position maps to a (learnable) vector
        else:
            position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]

        x = self.drop(token_embeddings + position_embeddings)
        presents = []  # accumulate over layers
        for i, block in enumerate(self.blocks):
            x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
            presents.append(present)

        x = self.ln_f(x)
        logits = self.head(x)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, torch.stack(presents)  # _, _, n_layer, 2, b, nh, 1, dim_head


class DummyGPT(nn.Module):
    # for debugging
    def __init__(self, add_value=1):
        super().__init__()
        self.add_value = add_value

    def forward(self, idx):
        return idx + self.add_value, None


class CodeGPT(nn.Module):
    """Takes in semi-embeddings"""
    def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
                 embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
        super().__init__()
        config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
                           embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
                           n_layer=n_layer, n_head=n_head, n_embd=n_embd,
                           n_unmasked=n_unmasked)
        # input embedding stem
        self.tok_emb = nn.Linear(in_channels, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.block_size = config.block_size
        self.apply(self._init_weights)
        self.config = config
        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, embeddings=None, targets=None):
        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

        if embeddings is not None:  # prepend explicit embeddings
            token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss


#### sampling utils

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
|
321 |
+
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
|
325 |
+
top_k=None, top_p=None, callback=None):
|
326 |
+
# x is conditioning
|
327 |
+
sample = x
|
328 |
+
cond_len = x.shape[1]
|
329 |
+
past = None
|
330 |
+
for n in range(steps):
|
331 |
+
if callback is not None:
|
332 |
+
callback(n)
|
333 |
+
logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
|
334 |
+
if past is None:
|
335 |
+
past = [present]
|
336 |
+
else:
|
337 |
+
past.append(present)
|
338 |
+
logits = logits[:, -1, :] / temperature
|
339 |
+
if top_k is not None:
|
340 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
341 |
+
|
342 |
+
probs = F.softmax(logits, dim=-1)
|
343 |
+
if not sample_logits:
|
344 |
+
_, x = torch.topk(probs, k=1, dim=-1)
|
345 |
+
else:
|
346 |
+
x = torch.multinomial(probs, num_samples=1)
|
347 |
+
# append to the sequence and continue
|
348 |
+
sample = torch.cat((sample, x), dim=1)
|
349 |
+
del past
|
350 |
+
sample = sample[:, cond_len:] # cut conditioning off
|
351 |
+
return sample
|
352 |
+
|
353 |
+
|
354 |
+
#### clustering utils
|
355 |
+
|
356 |
+
class KMeans(nn.Module):
|
357 |
+
def __init__(self, ncluster=512, nc=3, niter=10):
|
358 |
+
super().__init__()
|
359 |
+
self.ncluster = ncluster
|
360 |
+
self.nc = nc
|
361 |
+
self.niter = niter
|
362 |
+
self.shape = (3,32,32)
|
363 |
+
self.register_buffer("C", torch.zeros(self.ncluster,nc))
|
364 |
+
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
|
365 |
+
|
366 |
+
def is_initialized(self):
|
367 |
+
return self.initialized.item() == 1
|
368 |
+
|
369 |
+
@torch.no_grad()
|
370 |
+
def initialize(self, x):
|
371 |
+
N, D = x.shape
|
372 |
+
assert D == self.nc, D
|
373 |
+
c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
|
374 |
+
for i in range(self.niter):
|
375 |
+
# assign all pixels to the closest codebook element
|
376 |
+
a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
|
377 |
+
# move each codebook element to be the mean of the pixels that assigned to it
|
378 |
+
c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
|
379 |
+
# re-assign any poorly positioned codebook elements
|
380 |
+
nanix = torch.any(torch.isnan(c), dim=1)
|
381 |
+
ndead = nanix.sum().item()
|
382 |
+
print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
|
383 |
+
c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
|
384 |
+
|
385 |
+
self.C.copy_(c)
|
386 |
+
self.initialized.fill_(1)
|
387 |
+
|
388 |
+
|
389 |
+
def forward(self, x, reverse=False, shape=None):
|
390 |
+
if not reverse:
|
391 |
+
# flatten
|
392 |
+
bs,c,h,w = x.shape
|
393 |
+
assert c == self.nc
|
394 |
+
x = x.reshape(bs,c,h*w,1)
|
395 |
+
C = self.C.permute(1,0)
|
396 |
+
C = C.reshape(1,c,1,self.ncluster)
|
397 |
+
a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
|
398 |
+
return a
|
399 |
+
else:
|
400 |
+
# flatten
|
401 |
+
bs, HW = x.shape
|
402 |
+
"""
|
403 |
+
c = self.C.reshape( 1, self.nc, 1, self.ncluster)
|
404 |
+
c = c[bs*[0],:,:,:]
|
405 |
+
c = c[:,:,HW*[0],:]
|
406 |
+
x = x.reshape(bs, 1, HW, 1)
|
407 |
+
x = x[:,3*[0],:,:]
|
408 |
+
x = torch.gather(c, dim=3, index=x)
|
409 |
+
"""
|
410 |
+
x = self.C[x]
|
411 |
+
x = x.permute(0,2,1)
|
412 |
+
shape = shape if shape is not None else self.shape
|
413 |
+
x = x.reshape(bs, *shape)
|
414 |
+
|
415 |
+
return x
|
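A quick illustration of the top-k filtering used by both sampling loops above (a minimal sketch, not part of the upload, assuming the module is importable as taming.modules.transformer.mingpt): logits below the k-th largest value are pushed to -inf so softmax assigns them zero probability.

import torch
from taming.modules.transformer.mingpt import top_k_logits

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])   # one row of next-token logits
filtered = top_k_logits(logits, k=2)             # keep only the two largest entries
print(filtered)                                  # tensor([[2., -inf, 1., -inf]])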
taming/modules/transformer/permuter.py
ADDED
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import numpy as np


class AbstractPermuter(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, reverse=False):
        raise NotImplementedError


class Identity(AbstractPermuter):
    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        return x


class Subsample(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H*W).reshape(C,H,W)
        while min(H, W) > 1:
            indices = indices.reshape(C,H//2,2,W//2,2)
            indices = indices.transpose(0,2,4,1,3)
            indices = indices.reshape(C*4,H//2, W//2)
            H = H//2
            W = W//2
            C = C*4
        assert H == W == 1
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx',
                             nn.Parameter(idx, requires_grad=False))
        self.register_buffer('backward_shuffle_idx',
                             nn.Parameter(torch.argsort(idx), requires_grad=False))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


def mortonify(i, j):
    """(i,j) index to linear morton code"""
    i = np.uint64(i)
    j = np.uint64(j)

    z = np.uint(0)

    for pos in range(32):
        z = (z |
             ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
             ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
             )
    return z


class ZCurve(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
        idx = np.argsort(reverseidx)
        idx = torch.tensor(idx)
        reverseidx = torch.tensor(reverseidx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', reverseidx)

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class SpiralOut(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size,size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class SpiralIn(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size,size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        idx = idx[::-1]
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class Random(nn.Module):
    def __init__(self, H, W):
        super().__init__()
        indices = np.random.RandomState(1).permutation(H*W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class AlternateParsing(AbstractPermuter):
    def __init__(self, H, W):
        super().__init__()
        indices = np.arange(W*H).reshape(H,W)
        for i in range(1, H, 2):
            indices[i, :] = indices[i, ::-1]
        idx = indices.flatten()
        assert len(idx) == H*W
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


if __name__ == "__main__":
    p0 = AlternateParsing(16, 16)
    print(p0.forward_shuffle_idx)
    print(p0.backward_shuffle_idx)

    x = torch.randint(0, 768, size=(11, 256))
    y = p0(x)
    xre = p0(y, reverse=True)
    assert torch.equal(x, xre)

    p1 = SpiralOut(2, 2)
    print(p1.forward_shuffle_idx)
    print(p1.backward_shuffle_idx)
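For orientation, a minimal round-trip sketch (illustrative only, assuming the file is importable as taming.modules.transformer.permuter): each permuter maps a flattened token grid into a different autoregressive ordering, and reverse=True undoes it.

import torch
from taming.modules.transformer.permuter import ZCurve

perm = ZCurve(8, 8)                      # z-curve ordering of an 8x8 latent grid
tokens = torch.arange(64)[None, :]       # one flattened sequence of 64 token ids
shuffled = perm(tokens)                  # raster order -> z-curve order
restored = perm(shuffled, reverse=True)  # back to raster order
assert torch.equal(tokens, restored)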
taming/modules/util.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:,:,None,None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""
    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        c = c[:,None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    # for unconditional training
    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # get batch size from data and replicate sos_token
        c = torch.ones(x.shape[0], 1)*self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
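As a small usage sketch (illustrative, not part of the upload): SOSProvider ignores the content of its input and only uses the batch size and device, returning start-of-sequence tokens in the same three-tuple format as the quantizers.

import torch
from taming.modules.util import SOSProvider, count_params

x = torch.zeros(4, 3, 16, 16)             # any batch; only x.shape[0] and x.device matter
provider = SOSProvider(sos_token=0)
c, _, (_, _, indices) = provider.encode(x)
print(c.shape)                            # torch.Size([4, 1])
print(count_params(provider))             # 0 -- the provider has no learnable parameters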
taming/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,445 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
    # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
    # used wherever VectorQuantizer has been used before and is additionally
    # more efficient.
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        #.........\end

        # with:
        # .........\start
        #min_encoding_indices = torch.argmin(d, dim=1)
        #z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:,None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """
    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random"):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_embed = n_embed

        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds>=self.used.shape[0]] = 0  # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, return_logits=False):
        # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        if self.remap is not None:
            # continue only with used logits
            full_zeros = torch.zeros_like(logits)
            logits = logits[:,self.used,...]

        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # go back to all entries but unused set to zero
            full_zeros[:,self.used,...] = soft_one_hot
            soft_one_hot = full_zeros
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()

        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind

    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape
        assert b*h*w == indices.shape[0]
        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
        if self.remap is not None:
            indices = self.unmap_to_all(indices)
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
        return z_q


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds>=self.used.shape[0]] = 0  # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0],-1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class EmbeddingEMA(nn.Module):
    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        weight = torch.randn(num_tokens, codebook_dim)
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    def forward(self, embed_id):
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        n = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
        )
        # normalize embedding average with smoothed cluster size
        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
        self.weight.data.copy_(embed_normalized)


class EMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 remap=None, unknown_index="random"):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds>=self.used.shape[0]] = 0  # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        #z, 'b c h w -> b h w c'
        z = rearrange(z, 'b c h w -> b h w c')
        z_flattened = z.reshape(-1, self.codebook_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training and self.embedding.update:
            # EMA cluster size
            encodings_sum = encodings.sum(0)
            self.embedding.cluster_size_ema_update(encodings_sum)
            # EMA embedding average
            embed_sum = encodings.transpose(0,1) @ z_flattened
            self.embedding.embed_avg_ema_update(embed_sum)
            # normalize embed_avg and update weight
            self.embedding.weight_update(self.num_tokens)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        #z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, (perplexity, encodings, encoding_indices)
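A hedged sketch of the quantizer interface (illustrative values only, assuming the module path taming.modules.vqvae.quantize): the forward pass snaps each spatial feature vector to its nearest codebook entry and returns the straight-through quantized tensor, the commitment loss, and the chosen indices.

import torch
from taming.modules.vqvae.quantize import VectorQuantizer2

quantizer = VectorQuantizer2(n_e=128, e_dim=64, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 64, 4, 4)               # (batch, channels, height, width) encoder output
z_q, loss, (_, _, indices) = quantizer(z)
print(z_q.shape)                           # torch.Size([2, 64, 4, 4]) -- built from codebook entries
print(indices.shape)                       # torch.Size([2, 4, 4]) with sane_index_shape=True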
taming/util.py
ADDED
@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")
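One more illustrative call pattern for retrieve (hypothetical config values, not part of the upload): keys are given as a single path-like string, and a non-None default suppresses the KeyNotFoundError for missing paths.

from taming.util import retrieve

config = {"model": {"params": {"embed_dim": 256}}}
print(retrieve(config, "model/params/embed_dim"))             # 256
print(retrieve(config, "model/params/n_layers", default=12))  # 12 (key missing, default returned)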