UPstud committed on
Commit 519be3e · verified · 1 Parent(s): e9c9507

Upload 5 files

Files changed (5)
  1. CtrlColor_environ.yaml +40 -0
  2. app.py +524 -0
  3. config.py +1 -0
  4. requirements.txt +29 -0
  5. share.py +8 -0
CtrlColor_environ.yaml ADDED
@@ -0,0 +1,40 @@
1
+ name: CtrlColor
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.12.1
10
+ - torchvision=0.13.1
11
+ - numpy=1.23.1
12
+ - pip:
13
+ - gradio==3.31.0
14
+ - gradio-client==0.2.5
15
+ - albumentations==1.3.0
16
+ - opencv-python==4.9.0.80
17
+ - opencv-python-headless==4.5.5.64
18
+ - imageio==2.9.0
19
+ - imageio-ffmpeg==0.4.2
20
+ - pytorch-lightning==1.5.0
21
+ - omegaconf==2.1.1
22
+ - test-tube>=0.7.5
23
+ - streamlit==1.12.1
24
+ - webdataset==0.2.5
25
+ - kornia==0.6
26
+ - open_clip_torch==2.0.2
27
+ - invisible-watermark>=0.1.5
28
+ - streamlit-drawable-canvas==0.8.0
29
+ - torchmetrics==0.6.0
30
+ - addict==2.4.0
31
+ - yapf==0.32.0
32
+ - prettytable==3.6.0
33
+ - basicsr==1.4.2
34
+ - salesforce-lavis==1.0.2
35
+ - grpcio==1.60
36
+ - pydantic==1.10.5
37
+ - spacy==3.5.1
38
+ - typer==0.7.0
39
+ - typing-extensions==4.4.0
40
+ - fastapi==0.92.0
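
The environment above pins Python 3.8.5 with PyTorch 1.12.1 built against CUDA 11.3. A minimal sketch for checking that a freshly created environment (e.g. via conda env create -f CtrlColor_environ.yaml) matches these pins; the expected values in the comments come from the file above:

import sys
import numpy, torch, torchvision

print("python     :", sys.version.split()[0])     # expected 3.8.x
print("numpy      :", numpy.__version__)          # expected 1.23.1
print("torch      :", torch.__version__)          # expected 1.12.1
print("torchvision:", torchvision.__version__)    # expected 0.13.1
print("cuda build :", torch.version.cuda, "| gpu available:", torch.cuda.is_available())  # expected 11.3
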
app.py ADDED
@@ -0,0 +1,524 @@
1
+ import os
2
+ from share import *
3
+ import config
4
+
5
+ import cv2
6
+ import einops
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import random
11
+
12
+ from pytorch_lightning import seed_everything
13
+ from annotator.util import resize_image
14
+ from cldm.model import create_model, load_state_dict
15
+ from cldm.ddim_haced_sag_step import DDIMSampler
16
+ from lavis.models import load_model_and_preprocess
17
+ from PIL import Image
18
+ import tqdm
19
+
20
+ from ldm.models.autoencoder_train import AutoencoderKL
21
+
22
+ ckpt_path="./pretrained_models/main_model.ckpt"
23
+
24
+ model = create_model('./models/cldm_v15_inpainting_infer1.yaml').cpu()
25
+ model.load_state_dict(load_state_dict(ckpt_path, location='cuda'),strict=False)
26
+ model = model.cuda()
27
+
28
+ ddim_sampler = DDIMSampler(model)
29
+
30
+
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ BLIP_model, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=device)
33
+
34
+ vae_model_ckpt_path="./pretrained_models/content-guided_deformable_vae.ckpt"
35
+
36
+ def load_vae():
37
+ init_config = {
38
+ "embed_dim": 4,
39
+ "monitor": "val/rec_loss",
40
+ "ddconfig":{
41
+ "double_z": True,
42
+ "z_channels": 4,
43
+ "resolution": 256,
44
+ "in_channels": 3,
45
+ "out_ch": 3,
46
+ "ch": 128,
47
+ "ch_mult":[1,2,4,4],
48
+ "num_res_blocks": 2,
49
+ "attn_resolutions": [],
50
+ "dropout": 0.0,
51
+ },
52
+ "lossconfig":{
53
+ "target": "ldm.modules.losses.LPIPSWithDiscriminator",
54
+ "params":{
55
+ "disc_start": 501,
56
+ "kl_weight": 0,
57
+ "disc_weight": 0.025,
58
+ "disc_factor": 1.0
59
+ }
60
+ }
61
+ }
62
+ vae = AutoencoderKL(**init_config)
63
+ vae.load_state_dict(load_state_dict(vae_model_ckpt_path, location='cuda'))
64
+ vae = vae.cuda()
65
+ return vae
66
+
67
+ vae_model=load_vae()
68
+
69
+ def encode_mask(mask,masked_image):
70
+ mask = torch.nn.functional.interpolate(mask, size=(mask.shape[2] // 8, mask.shape[3] // 8))
71
+ # mask=torch.cat([mask] * 2) #if do_classifier_free_guidance else mask
72
+ mask = mask.to(device="cuda")
73
+ # do_classifier_free_guidance=False
74
+ masked_image_latents = model.get_first_stage_encoding(model.encode_first_stage(masked_image.cuda())).detach()
75
+ return mask,masked_image_latents
76
+
77
+ def get_mask(input_image,hint_image):
78
+ mask=input_image.copy()
79
+ H,W,C=input_image.shape
80
+ for i in range(H):
81
+ for j in range(W):
82
+ if input_image[i,j,0]==hint_image[i,j,0]:
83
+ # print(input_image[i,j,0])
84
+ mask[i,j,:]=255.
85
+ else:
86
+ mask[i,j,:]=0. #input_image[i,j,:]
87
+ kernel=cv2.getStructuringElement(cv2.MORPH_RECT,(3,3))
88
+ mask=cv2.morphologyEx(np.array(mask),cv2.MORPH_OPEN,kernel,iterations=1)
89
+ return mask
90
+
91
+ def prepare_mask_and_masked_image(image, mask):
92
+ """
93
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
94
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
95
+ ``image`` and ``1`` for the ``mask``.
96
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
97
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
98
+ Args:
99
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
100
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
101
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
102
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
103
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
104
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
105
+ Raises:
106
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
107
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
108
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
109
+ (or the other way around).
110
+ Returns:
111
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
112
+ dimensions: ``batch x channels x height x width``.
113
+ """
114
+ if isinstance(image, torch.Tensor):
115
+ if not isinstance(mask, torch.Tensor):
116
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
117
+
118
+ # Batch single image
119
+ if image.ndim == 3:
120
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
121
+ image = image.unsqueeze(0)
122
+
123
+ # Batch and add channel dim for single mask
124
+ if mask.ndim == 2:
125
+ mask = mask.unsqueeze(0).unsqueeze(0)
126
+
127
+ # Batch single mask or add channel dim
128
+ if mask.ndim == 3:
129
+ # Single batched mask, no channel dim or single mask not batched but channel dim
130
+ if mask.shape[0] == 1:
131
+ mask = mask.unsqueeze(0)
132
+
133
+ # Batched masks no channel dim
134
+ else:
135
+ mask = mask.unsqueeze(1)
136
+
137
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
138
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
139
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
140
+
141
+ # Check image is in [-1, 1]
142
+ if image.min() < -1 or image.max() > 1:
143
+ raise ValueError("Image should be in [-1, 1] range")
144
+
145
+ # Check mask is in [0, 1]
146
+ if mask.min() < 0 or mask.max() > 1:
147
+ raise ValueError("Mask should be in [0, 1] range")
148
+
149
+ # Binarize mask
150
+ mask[mask < 0.5] = 0
151
+ mask[mask >= 0.5] = 1
152
+
153
+ # Image as float32
154
+ image = image.to(dtype=torch.float32)
155
+ elif isinstance(mask, torch.Tensor):
156
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
157
+ else:
158
+ # preprocess image
159
+ if isinstance(image, (Image.Image, np.ndarray)):
160
+ image = [image]
161
+
162
+ if isinstance(image, list) and isinstance(image[0], Image.Image):
163
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
164
+ image = np.concatenate(image, axis=0)
165
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
166
+ image = np.concatenate([i[None, :] for i in image], axis=0)
167
+
168
+ image = image.transpose(0, 3, 1, 2)
169
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
170
+
171
+ # preprocess mask
172
+ if isinstance(mask, (Image.Image, np.ndarray)):
173
+ mask = [mask]
174
+
175
+ if isinstance(mask, list) and isinstance(mask[0], Image.Image):
176
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
177
+ mask = mask.astype(np.float32) / 255.0
178
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
179
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
180
+
181
+ mask[mask < 0.5] = 0
182
+ mask[mask >= 0.5] = 1
183
+ mask = torch.from_numpy(mask)
184
+
185
+ masked_image = image * (mask < 0.5)
186
+
187
+ return mask, masked_image
188
+
189
+ # generate image
190
+ generator = torch.manual_seed(859311133)#0
191
+ def path2L(img_path):
192
+ raw_image = cv2.imread(img_path)
193
+ raw_image = cv2.cvtColor(raw_image,cv2.COLOR_BGR2LAB)
194
+ raw_image_input = cv2.merge([raw_image[:,:,0],raw_image[:,:,0],raw_image[:,:,0]])
195
+ return raw_image_input
196
+
197
+ def is_gray_scale(img, threshold=10):
198
+ img = Image.fromarray(img)
199
+ if len(img.getbands()) == 1:
200
+ return True
201
+ img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
202
+ img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
203
+ img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
204
+ diff1 = (img1 - img2).var()
205
+ diff2 = (img2 - img3).var()
206
+ diff3 = (img3 - img1).var()
207
+ diff_sum = (diff1 + diff2 + diff3) / 3.0
208
+ if diff_sum <= threshold:
209
+ return True
210
+ else:
211
+ return False
212
+
213
+ def randn_tensor(
214
+ shape,
215
+ generator= None,
216
+ device= None,
217
+ dtype=None,
218
+ layout= None,
219
+ ):
220
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
221
+ passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the tensor
222
+ is always created on the CPU.
223
+ """
224
+ # device on which tensor is created defaults to device
225
+ rand_device = device
226
+ batch_size = shape[0]
227
+
228
+ layout = layout or torch.strided
229
+ device = device or torch.device("cpu")
230
+
231
+ if generator is not None:
232
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
233
+ if gen_device_type != device.type and gen_device_type == "cpu":
234
+ rand_device = "cpu"
235
+ if device != "mps":
236
+ print(f"The passed generator was created on 'cpu' even though a tensor on {device} was expected.")
237
+ # logger.info(
238
+ # f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
239
+ # f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
240
+ # f" slightly speed up this function by passing a generator that was created on the {device} device."
241
+ # )
242
+ elif gen_device_type != device.type and gen_device_type == "cuda":
243
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
244
+
245
+ # make sure generator list of length 1 is treated like a non-list
246
+ if isinstance(generator, list) and len(generator) == 1:
247
+ generator = generator[0]
248
+
249
+ if isinstance(generator, list):
250
+ shape = (1,) + shape[1:]
251
+ latents = [
252
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
253
+ for i in range(batch_size)
254
+ ]
255
+ latents = torch.cat(latents, dim=0).to(device)
256
+ else:
257
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
258
+
259
+ return latents
260
+
261
+ def add_noise(
262
+ original_samples: torch.FloatTensor,
263
+ noise: torch.FloatTensor,
264
+ timesteps: torch.IntTensor,
265
+ ) -> torch.FloatTensor:
266
+ betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32)
267
+ alphas = 1.0 - betas
268
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
269
+ alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
270
+ timesteps = timesteps.to(original_samples.device)
271
+
272
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
273
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
274
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
275
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
276
+
277
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
278
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
279
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
280
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
281
+
282
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
283
+
284
+ return noisy_samples
285
+
286
+ def set_timesteps(num_inference_steps: int, timestep_spacing="leading",device=None):
287
+ """
288
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
289
+
290
+ Args:
291
+ num_inference_steps (`int`):
292
+ the number of diffusion steps used when generating samples with a pre-trained model.
293
+ """
294
+ num_train_timesteps=1000
295
+ if num_inference_steps > num_train_timesteps:
296
+ raise ValueError(
297
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
298
+ f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
299
+ f" maximal {num_train_timesteps} timesteps."
300
+ )
301
+
302
+ num_inference_steps = num_inference_steps
303
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
304
+ if timestep_spacing == "linspace":
305
+ timesteps = (
306
+ np.linspace(0, num_train_timesteps - 1, num_inference_steps)
307
+ .round()[::-1]
308
+ .copy()
309
+ .astype(np.int64)
310
+ )
311
+ elif timestep_spacing == "leading":
312
+ step_ratio = num_train_timesteps // num_inference_steps
313
+ # creates integer timesteps by multiplying by ratio
314
+ # casting to int to avoid issues when num_inference_step is power of 3
315
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
316
+ # timesteps += steps_offset
317
+ elif timestep_spacing == "trailing":
318
+ step_ratio = num_train_timesteps / num_inference_steps
319
+ # creates integer timesteps by multiplying by ratio
320
+ # casting to int to avoid issues when num_inference_step is power of 3
321
+ timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)
322
+ timesteps -= 1
323
+ else:
324
+ raise ValueError(
325
+ f"{timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading', or 'trailing'."
326
+ )
327
+
328
+ timesteps = torch.from_numpy(timesteps).to(device)
329
+ return timesteps
330
+
331
+ def get_timesteps(num_inference_steps, timesteps_set, strength, device):
332
+ # get the original timestep using init_timestep
333
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
334
+
335
+ t_start = max(num_inference_steps - init_timestep, 0)
336
+ timesteps = timesteps_set[t_start * 1 :]
337
+
338
+ return timesteps, num_inference_steps - t_start
339
+
340
+
341
+ def get_noised_image_latents(img,W,H,ddim_steps,strength,seed,device):
342
+ img1 = [cv2.resize(img,(W,H))]
343
+ img1 = np.concatenate([i[None, :] for i in img1], axis=0)
344
+ img1 = img1.transpose(0, 3, 1, 2)
345
+ img1 = torch.from_numpy(img1).to(dtype=torch.float32) /127.5 - 1.0
346
+
347
+ image_latents=model.get_first_stage_encoding(model.encode_first_stage(img1.cuda())).detach()
348
+ shape=image_latents.shape
349
+ generator = torch.manual_seed(seed)
350
+
351
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
352
+
353
+ timesteps_set=set_timesteps(ddim_steps,timestep_spacing="linspace", device=device)
354
+ timesteps, num_inference_steps = get_timesteps(ddim_steps, timesteps_set, strength, device)
355
+ latent_timestep = timesteps[1].repeat(1 * 1)
356
+
357
+ init_latents = add_noise(image_latents, noise, torch.tensor(latent_timestep))
358
+ for j in range(0, 1000, 100):
359
+
360
+ x_samples=model.decode_first_stage(add_noise(image_latents, noise, torch.tensor(j)))
361
+ init_image=(einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
362
+
363
+ cv2.imwrite("./initlatents1/"+str(j)+"init_image.png",cv2.cvtColor(init_image[0],cv2.COLOR_RGB2BGR))
364
+ return init_latents
365
+
366
+ def process(using_deformable_vae,change_according_to_strokes,iterative_editing,input_image,hint_image,prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, sag_scale,SAG_influence_step, seed, eta):
367
+ torch.cuda.empty_cache()
368
+ with torch.no_grad():
369
+ ref_flag=True
370
+ input_image_ori=input_image
371
+ if is_gray_scale(input_image):
372
+ print("It is a greyscale image.")
373
+ # mask=get_mask(input_image,hint_image)
374
+ else:
375
+ print("It is a color image.")
376
+ input_image_ori=input_image
377
+ input_image=cv2.cvtColor(input_image,cv2.COLOR_RGB2LAB)[:,:,0]
378
+ input_image=cv2.merge([input_image,input_image,input_image])
379
+ mask=get_mask(input_image_ori,hint_image)
380
+ cv2.imwrite("gradio_mask1.png",mask)
381
+
382
+ if iterative_editing:
383
+ mask=255-mask
384
+ if change_according_to_strokes:
385
+ hint_image=mask/255.*hint_image+(1-mask/255.)*input_image_ori
386
+ else:
387
+ hint_image=mask/255.*input_image+(1-mask/255.)*input_image_ori
388
+ else:
389
+ hint_image=mask/255.*input_image+(1-mask/255.)*hint_image
390
+ hint_image=hint_image.astype(np.uint8)
391
+ if len(prompt)==0:
392
+ image = Image.fromarray(input_image)
393
+ image = vis_processors["eval"](image).unsqueeze(0).to(device)
394
+ prompt = BLIP_model.generate({"image": image})[0]
395
+ if "a black and white photo of" in prompt or "black and white photograph of" in prompt:
396
+ prompt=prompt.replace(prompt[:prompt.find("of")+3],"")
397
+ print(prompt)
398
+ H_ori,W_ori,C_ori=input_image.shape
399
+ img = resize_image(input_image, image_resolution)
400
+ mask = resize_image(mask, image_resolution)
401
+ hint_image =resize_image(hint_image,image_resolution)
402
+ mask,masked_image=prepare_mask_and_masked_image(Image.fromarray(hint_image),Image.fromarray(mask))
403
+ mask,masked_image_latents=encode_mask(mask,masked_image)
404
+ H, W, C = img.shape
405
+
406
+ # if ref_image is None:
407
+ ref_image=np.array([[[0]*C]*W]*H).astype(np.float32)
408
+ # print(ref_image.shape)
409
+ # ref_flag=False
410
+ ref_image=resize_image(ref_image,image_resolution)
411
+
412
+ # cv2.imwrite("exemplar_image.png",cv2.cvtColor(ref_image,cv2.COLOR_RGB2BGR))
413
+
414
+ # ddim_steps=1
415
+ control = torch.from_numpy(img.copy()).float().cuda() / 255.0
416
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
417
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
418
+
419
+ if seed == -1:
420
+ seed = random.randint(0, 65535)
421
+ seed_everything(seed)
422
+
423
+ ref_image=cv2.resize(ref_image,(W,H))
424
+
425
+ ref_image=torch.from_numpy(ref_image).cuda().unsqueeze(0)
426
+
427
+ init_latents=None
428
+
429
+ if config.save_memory:
430
+ model.low_vram_shift(is_diffusing=False)
431
+
432
+ print("no reference images, using Frozen encoder")
433
+ cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
434
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
435
+ shape = (4, H // 8, W // 8)
436
+
437
+ if config.save_memory:
438
+ model.low_vram_shift(is_diffusing=True)
439
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
440
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
441
+ samples, intermediates = ddim_sampler.sample(model,ddim_steps, num_samples,
442
+ shape, cond, mask=mask, masked_image_latents=masked_image_latents,verbose=False, eta=eta,
443
+ # x_T=image_latents,
444
+ x_T=init_latents,
445
+ unconditional_guidance_scale=scale,
446
+ sag_scale = sag_scale,
447
+ SAG_influence_step=SAG_influence_step,
448
+ noise = noise,
449
+ unconditional_conditioning=un_cond)
450
+
451
+
452
+ if config.save_memory:
453
+ model.low_vram_shift(is_diffusing=False)
454
+
455
+ if not using_deformable_vae:
456
+ x_samples = model.decode_first_stage(samples)
457
+ else:
458
+ samples = model.decode_first_stage_before_vae(samples)
459
+ gray_content_z=vae_model.get_gray_content_z(torch.from_numpy(img.copy()).float().cuda() / 255.0)
460
+ # print(gray_content_z.shape)
461
+ x_samples = vae_model.decode(samples,gray_content_z)
462
+
463
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
464
+
465
+ #single image replace L channel
466
+ results_ori = [x_samples[i] for i in range(num_samples)]
467
+ results_ori=[cv2.resize(i,(W_ori,H_ori),interpolation=cv2.INTER_LANCZOS4) for i in results_ori]
468
+
469
+ cv2.imwrite("result_ori.png",cv2.cvtColor(results_ori[0],cv2.COLOR_RGB2BGR))
470
+
471
+ results_tmp=[cv2.cvtColor(np.array(i),cv2.COLOR_RGB2LAB) for i in results_ori]
472
+ results=[cv2.merge([input_image[:,:,0],tmp[:,:,1],tmp[:,:,2]]) for tmp in results_tmp]
473
+ results_mergeL=[cv2.cvtColor(np.asarray(i),cv2.COLOR_LAB2RGB) for i in results]#cv2.COLOR_LAB2BGR)
474
+ cv2.imwrite("output.png",cv2.cvtColor(results_mergeL[0],cv2.COLOR_RGB2BGR))
475
+ return results_mergeL
476
+
477
+ def get_grayscale_img(img, progress=gr.Progress(track_tqdm=True)):
478
+ torch.cuda.empty_cache()
479
+ for j in tqdm.tqdm(range(1),desc="Uploading input..."):
480
+ return img,"Uploading input image done."
481
+
482
+ block = gr.Blocks().queue()
483
+ with block:
484
+ with gr.Row():
485
+ gr.Markdown("## Control-Color")#("## Color-Anything")#Control Stable Diffusion with L channel
486
+ with gr.Row():
487
+ with gr.Column():
488
+ # input_image = gr.Image(source='upload', type="numpy")
489
+ grayscale_img = gr.Image(visible=False, type="numpy")
490
+ input_image = gr.Image(source='upload',tool='color-sketch',interactive=True)
491
+ Grayscale_button = gr.Button(value="Upload input image")
492
+ text_out = gr.Textbox(value="Please upload the input image first, then draw strokes, enter a text prompt, or provide a reference image as you wish.")
493
+ prompt = gr.Textbox(label="Prompt")
494
+ change_according_to_strokes = gr.Checkbox(label='Change according to strokes\' color', value=True)
495
+ iterative_editing = gr.Checkbox(label='Only change the strokes\' area', value=False)
496
+ using_deformable_vae = gr.Checkbox(label='Using deformable vae. (Less color overflow)', value=False)
497
+ # with gr.Accordion("Input Reference", open=False):
498
+ # ref_image = gr.Image(source='upload', type="numpy")
499
+ run_button = gr.Button(label="Upload prompts/strokes (optional) and Run",value="Upload prompts/strokes (optional) and Run")
500
+ with gr.Accordion("Advanced options", open=False):
501
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
502
+ image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
503
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
504
+ guess_mode = gr.Checkbox(label='Guess Mode', value=False)
505
+ #detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
506
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
507
+ scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.0, step=0.1)#value=9.0
508
+ sag_scale = gr.Slider(label="SAG Scale", minimum=0.0, maximum=1.0, value=0.05, step=0.01)#0.08
509
+ SAG_influence_step = gr.Slider(label="1000-SAG influence step", minimum=0, maximum=900, value=600, step=50)
510
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)#94433242802
511
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
512
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality, detailed, real')#extremely detailed
513
+ n_prompt = gr.Textbox(label="Negative Prompt",
514
+ value='a black and white photo, longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
515
+ with gr.Column():
516
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
517
+ # grayscale_img = gr.Image(interactive=False,visible=False)
518
+
519
+ Grayscale_button.click(fn=get_grayscale_img,inputs=input_image,outputs=[grayscale_img,text_out])
520
+ ips = [using_deformable_vae,change_according_to_strokes,iterative_editing,grayscale_img,input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale,sag_scale,SAG_influence_step, seed, eta]
521
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
522
+
523
+
524
+ block.launch(server_name='0.0.0.0',share=True)
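
The add_noise, set_timesteps and get_timesteps helpers above reimplement the DDPM forward process and timestep spacing locally (1000 training steps, linear betas from 0.00085 to 0.0120) rather than relying on an external scheduler. A small sketch of what they compute, using the same constants; the concrete numbers (20 steps, strength 0.8) are illustrative only:

import numpy as np
import torch

# Linear beta schedule, as in add_noise() above.
betas = torch.linspace(0.00085, 0.0120, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Forward process used by add_noise(): x_t = sqrt(alpha_bar_t)*x_0 + sqrt(1 - alpha_bar_t)*eps
t = 750
x0 = torch.zeros(1, 4, 64, 64)                 # latent of the shape app.py uses: (4, H//8, W//8)
eps = torch.randn_like(x0)
x_t = alphas_cumprod[t].sqrt() * x0 + (1.0 - alphas_cumprod[t]).sqrt() * eps

# "leading" spacing from set_timesteps(): map 1000 training steps onto 20 inference steps.
ddim_steps = 20
step_ratio = 1000 // ddim_steps
timesteps = (np.arange(ddim_steps) * step_ratio).round()[::-1].astype(np.int64)
print(timesteps)                               # [950 900 ... 50 0]

# get_timesteps() with strength=0.8 drops the noisiest 20% of that schedule,
# so sampling can start from a partially noised latent (img2img-style editing).
strength = 0.8
t_start = ddim_steps - min(int(ddim_steps * strength), ddim_steps)
print(timesteps[t_start:])                     # [750 700 ... 50 0] -> 16 steps remain
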
config.py ADDED
@@ -0,0 +1 @@
1
+ save_memory = False
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ gradio
2
+ gradio-client
3
+ albumentations==1.3.0
4
+ opencv-python==4.9.0.80
5
+ opencv-python-headless==4.5.5.64
6
+ imageio==2.9.0
7
+ imageio-ffmpeg==0.4.2
8
+ pytorch-lightning==1.5.0
9
+ omegaconf==2.1.1
10
+ test-tube>=0.7.5
11
+ streamlit==1.12.1
12
+ webdataset==0.2.5
13
+ kornia==0.6
14
+ open_clip_torch==2.0.2
15
+ invisible-watermark>=0.1.5
16
+ streamlit-drawable-canvas==0.8.0
17
+ torchmetrics==0.6.0
18
+ addict==2.4.0
19
+ yapf==0.32.0
20
+ prettytable==3.6.0
21
+ basicsr==1.4.2
22
+ salesforce-lavis==1.0.2
23
+ grpcio==1.60
24
+ pydantic==1.10.5
25
+ wandb==0.15.12
26
+ spacy==3.5.1
27
+ typer==0.7.0
28
+ typing-extensions==4.4.0
29
+ fastapi==0.92.0
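
requirements.txt mirrors the pip section of CtrlColor_environ.yaml, adds wandb==0.15.12, and leaves gradio and gradio-client unpinned. A quick post-install check of a few of these pins, assuming pip install -r requirements.txt has already been run (standard library only):

from importlib.metadata import version

for pkg in ("pytorch-lightning", "omegaconf", "kornia", "gradio", "wandb"):
    print(pkg, "->", version(pkg))
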
share.py ADDED
@@ -0,0 +1,8 @@
1
+ import config
2
+ from cldm.hack import disable_verbosity, enable_sliced_attention
3
+
4
+
5
+ disable_verbosity()
6
+
7
+ if config.save_memory:
8
+ enable_sliced_attention()
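
share.py acts at import time: it always silences logging, and also enables sliced attention when config.save_memory is True (config.py ships the flag as False); app.py gates its model.low_vram_shift calls on the same flag. A sketch of opting in to the low-memory path, assuming the flag is flipped before share is imported:

import config
config.save_memory = True      # config.py defaults this to False

from share import *            # disable_verbosity(); now also enable_sliced_attention()

# app.py additionally shuttles sub-models between CPU and GPU around sampling, e.g.
#   if config.save_memory:
#       model.low_vram_shift(is_diffusing=False)
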