UPstud committed on
Commit
6e4fd82
·
verified ·
1 Parent(s): 906e982

Upload 10 files

cldm/__pycache__/cldm.cpython-38.pyc ADDED
Binary file (12 kB).
 
cldm/__pycache__/ddim_haced_sag_step.cpython-38.pyc ADDED
Binary file (13.3 kB).
 
cldm/__pycache__/hack.cpython-310.pyc ADDED
Binary file (3.88 kB).
 
cldm/__pycache__/hack.cpython-38.pyc ADDED
Binary file (3.89 kB).
 
cldm/__pycache__/model.cpython-38.pyc ADDED
Binary file (1.09 kB).
 
cldm/cldm.py ADDED
@@ -0,0 +1,547 @@
import einops
import torch
import torch as th
import torch.nn as nn

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.attention_dcn_control import SpatialTransformer_dcn
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


class ControlledUnetModel(UNetModel):
    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        hs = []
        # print("timestep", timesteps)
        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
            # print("t_emb", t_emb)
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)  # ,timestep=timesteps)
                hs.append(h)
            h = self.middle_block(h, emb, context)  # ,timestep=timesteps)

        if control is not None:
            h += control.pop()

        for i, module in enumerate(self.output_blocks):
            # print("output_blocks0", h.shape)
            if only_mid_control or control is None:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
            h = module(h, emb, context)  # ,timestep=timesteps)

        # print("output_blocks", h.shape)

        h = h.type(x.dtype)
        h = self.out(h)
        # print("self.out", h.shape)
        return h


class ControlNet(nn.Module):
    def __init__(
            self,
            image_size,
            in_channels,
            model_channels,
            hint_channels,
            num_res_blocks,
            attention_resolutions,
            dropout=0,
            channel_mult=(1, 2, 4, 8),
            conv_resample=True,
            dims=2,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=-1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            use_new_attention_order=False,
            use_spatial_transformer=False,  # custom transformer support
            transformer_depth=1,  # custom transformer support
            context_dim=None,  # custom transformer support
            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
            legacy=True,
            disable_self_attentions=None,
            num_attention_blocks=None,
            disable_middle_self_attn=False,
            use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        # print("cldm", hint.shape, x.shape)
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        # h_in = h

        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)  # ,dcn_guide=h_in)
                h += guided_hint
                guided_hint = None
            else:
                # print("dcn_guide")
                h = module(h, emb, context)  # ,dcn_guide=h_in)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)  # ,dcn_guide=h_in)
        outs.append(self.middle_block_out(h, emb, context))

        return outs


class ControlLDM(LatentDiffusion):

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):  # freeze
        # print(control_stage_config)
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        self.control_scales = [1.0] * 13
        # if freeze == True:
        #     self.freeze()

    # def freeze(self):
    #     # self.train = disabled_train
    #     for param in self.parameters():
    #         param.requires_grad = False

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        x, mask, masked_image_latents, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, mask, masked_image_latents, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, mask, masked_image_latents, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)
        # print(cond_txt.shape, cond['c_crossattn'].shape)
        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            control = [c * scale for c, scale in zip(control, self.control_scales)]
            mask = torch.cat([mask] * x_noisy.shape[0])
            masked_image_latents = torch.cat([masked_image_latents] * x_noisy.shape[0])
            x_noisy = torch.cat([x_noisy, mask, masked_image_latents], dim=1)
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    def apply_model_addhint(self, x_noisy, mask, masked_image_latents, t, cond, *args, **kwargs):
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)
        # print(cond_txt.shape, cond['c_crossattn'].shape)
        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            control = [c * scale for c, scale in zip(control, self.control_scales)]
            # print(x_noisy.shape, mask.shape, masked_image_latents.shape)
            x_noisy = torch.cat([x_noisy, mask, masked_image_latents], dim=1)
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        return self.get_learned_conditioning([""] * N)

    # def get_unconditional_conditioning(self, N, hint_image):
    #     hint_image[:, :, :, :] = 0
    #     return self.get_learned_conditioning(([""] * N, hint_image))

    # @torch.no_grad()
    # def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
    #                quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
    #                plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
    #                use_ema_scope=True,
    #                **kwargs):
    #     use_ddim = ddim_steps is not None

    #     log = dict()
    #     z, mask, masked_image_latents, c = self.get_input(batch, self.first_stage_key, bs=N)
    #     c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
    #     N = min(z.shape[0], N)
    #     n_row = min(z.shape[0], n_row)
    #     log["reconstruction"] = self.decode_first_stage(z)
    #     log["control"] = c_cat * 2.0 - 1.0
    #     log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
    #     txt, hint_image = batch[self.cond_stage_key]
    #     if plot_diffusion_rows:
    #         # get diffusion row
    #         diffusion_row = list()
    #         z_start = z[:n_row]
    #         for t in range(self.num_timesteps):
    #             if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
    #                 t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
    #                 t = t.to(self.device).long()
    #                 noise = torch.randn_like(z_start)
    #                 z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
    #                 diffusion_row.append(self.decode_first_stage(z_noisy))

    #         diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
    #         diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
    #         diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
    #         diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
    #         log["diffusion_row"] = diffusion_grid

    #     if sample:
    #         # get denoise row
    #         samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
    #                                                  batch_size=N, ddim=use_ddim,
    #                                                  ddim_steps=ddim_steps, eta=ddim_eta)
    #         x_samples = self.decode_first_stage(samples)
    #         log["samples"] = x_samples
    #         if plot_denoise_rows:
    #             denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
    #             log["denoise_row"] = denoise_grid

    #     if unconditional_guidance_scale > 1.0:
    #         uc_cross = self.get_unconditional_conditioning(N, hint_image)
    #         uc_cat = c_cat  # torch.zeros_like(c_cat)
    #         uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
    #         samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
    #                                          batch_size=N, ddim=use_ddim,
    #                                          ddim_steps=ddim_steps, eta=ddim_eta,
    #                                          unconditional_guidance_scale=unconditional_guidance_scale,
    #                                          unconditional_conditioning=uc_full,
    #                                          )
    #         x_samples_cfg = self.decode_first_stage(samples_cfg)
    #         log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

    #     return log

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        use_ddim = ddim_steps is not None

        log = dict()
        z, mask, masked_image_latents, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.masked_image], batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     mask=mask, masked_image_latents=masked_image_latents,
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                             mask=mask, masked_image_latents=masked_image_latents,
                                             batch_size=N, ddim=use_ddim,
                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_full,
                                             )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, mask, masked_image_latents, batch_size, ddim, ddim_steps, **kwargs):
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond,
                                                     mask=mask, masked_image_latents=masked_image_latents,
                                                     verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        # head_params = list()
        # for name, param in self.control_model.named_parameters():  # self.model.named_parameters():
        #     if "dcn" in name:
        #         # print(name)
        #         head_params.append(param)
        # # params = list(self.control_model.parameters()) + head_params
        # params = head_params
        if not self.sd_locked:
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
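
Note: `ControlLDM.apply_model` builds an inpainting-style UNet input by concatenating the noisy latent, the latent-resolution mask, and the masked-image latents along the channel dimension, while the ControlNet branch still receives only the plain noisy latent plus the hint. The snippet below is a minimal, hypothetical shape check of that concatenation (it assumes a standard SD inpainting layout of 4 + 1 + 4 channels on a 64×64 latent grid and is not part of the uploaded files):

```python
import torch

# Hypothetical shape check mirroring apply_model's concatenation (assumed 64x64 latent grid).
b, latent_ch, h, w = 2, 4, 64, 64
x_noisy = torch.randn(b, latent_ch, h, w)                # noisy latent fed to both branches
mask = torch.rand(1, 1, h, w)                            # single-channel mask at latent resolution
masked_image_latents = torch.randn(1, latent_ch, h, w)   # VAE-encoded masked image

# apply_model tiles the mask / masked latents to the batch size before concatenating.
mask = torch.cat([mask] * b)
masked_image_latents = torch.cat([masked_image_latents] * b)
unet_input = torch.cat([x_noisy, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64]) -> 9-channel inpainting-style UNet input
```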
cldm/ddim_haced_sag_step.py ADDED
@@ -0,0 +1,494 @@
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
import torch.nn.functional as F

import cv2

import einops


# Gaussian blur
def gaussian_blur_2d(img, kernel_size, sigma):
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img


# processes and stores attention probabilities
class CrossAttnStoreProcessor:
    def __init__(self):
        self.attention_probs = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(self.attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               model,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               masked_image_latents=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               sag_scale=0.75,
               SAG_influence_step=600,
               noise=None,
               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               dynamic_threshold=None,
               ucg_schedule=None,
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        # print(shape)
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(model, conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, masked_image_latents=masked_image_latents, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    sag_scale=sag_scale,
                                                    SAG_influence_step=SAG_influence_step,
                                                    noise=noise,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule
                                                    )
        return samples, intermediates

    def add_noise(self,
                  original_samples: torch.FloatTensor,
                  noise: torch.FloatTensor,
                  timesteps: torch.IntTensor,
                  ) -> torch.FloatTensor:
        betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

        return noisy_samples

    def sag_masking(self, original_latents, model_output, x, attn_map, map_size, t, eps):
        # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
        bh, hw1, hw2 = attn_map.shape
        b, latent_channel, latent_h, latent_w = original_latents.shape
        h = 4  # self.unet.config.attention_head_dim
        if isinstance(h, list):
            h = h[-1]
        attn_map = attn_map.reshape(b, h, hw1, hw2)
        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
        attn_mask = (
            attn_mask.reshape(b, map_size[0], map_size[1])
            .unsqueeze(1)
            .repeat(1, latent_channel, 1, 1)
            .type(attn_map.dtype)
        )
        attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)  # x / original_latents

        return degraded_latents

    def pred_epsilon(self, sample, model_output, timestep):
        alpha_prod_t = timestep

        beta_prod_t = 1 - alpha_prod_t
        # print(self.model.parameterization)  # eps
        if self.model.parameterization == "eps":
            pred_eps = model_output
        elif self.model.parameterization == "sample":
            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
        elif self.model.parameterization == "v":
            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `eps`, `sample`,"
                " or `v`"
            )

        return pred_eps

    @torch.no_grad()
    def ddim_sampling(self, model, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, masked_image_latents=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., sag_scale=0.75, SAG_influence_step=600, sag_enable=True,
                      noise=None, unconditional_conditioning=None, dynamic_threshold=None,
                      ucg_schedule=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        # timesteps = 100
        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]
        # timesteps = timesteps[:-3]
        # print("timesteps", timesteps)
        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            # print(step)
            if step > SAG_influence_step:
                sag_enable_t = True
            else:
                sag_enable_t = False
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if ucg_schedule is not None:
                assert len(ucg_schedule) == len(time_range)
                unconditional_guidance_scale = ucg_schedule[i]

            outs = self.p_sample_ddim(img, mask, masked_image_latents, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      sag_scale=sag_scale,
                                      sag_enable=sag_enable_t,
                                      noise=noise,
                                      unconditional_conditioning=unconditional_conditioning,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
            x_samples = model.decode_first_stage(img)
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

            # single image: replace L channel
            results_ori = [x_samples[i] for i in range(1)]
            # results_ori = [i for i in results_ori]

            # cv2.imwrite("result_ori" + str(step) + ".png", cv2.cvtColor(results_ori[0], cv2.COLOR_RGB2BGR))
        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, mask, masked_image_latents, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., sag_scale=0.75, sag_enable=True, noise=None, unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, mask, masked_image_latents, t, c)
        else:
            model_t = self.model.apply_model(x, mask, masked_image_latents, t, c)
            model_uncond = self.model.apply_model(x, mask, masked_image_latents, t, unconditional_conditioning)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()
        if sag_enable == True:
            uncond_attn, cond_attn = self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1.attention_probs.chunk(2)
            # self-attention-based degrading of latents
            map_size = self.model.model.diffusion_model.middle_block[1].map_size
            degraded_latents = self.sag_masking(
                pred_x0, model_output, x, uncond_attn, map_size, t, eps=noise,  # self.pred_epsilon(x, model_uncond, self.model.alphas_cumprod[t]), # noise
            )
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                degraded_model_output = self.model.apply_model(degraded_latents, mask, masked_image_latents, t, c)
            else:
                degraded_model_t = self.model.apply_model(degraded_latents, mask, masked_image_latents, t, c)
                degraded_model_uncond = self.model.apply_model(degraded_latents, mask, masked_image_latents, t, unconditional_conditioning)
                degraded_model_output = degraded_model_uncond + unconditional_guidance_scale * (degraded_model_t - degraded_model_uncond)
            # print("sag_scale", sag_scale)
            model_output += sag_scale * (model_output - degraded_model_output)
            # model_output = (1 - sag_scale) * model_output + sag_scale * degraded_model_output

        # current prediction for x_0
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        num_reference_steps = timesteps.shape[0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                           torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = alphas_next[i].sqrt() * (
                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (
                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)
            if callback: callback(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
               use_original_steps=False, callback=None):

        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
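
Note: the SAG path in `p_sample_ddim` degrades the current x0 prediction only where the mid-block self-attention is strong: the stored attention map is thresholded into a mask, the prediction is Gaussian-blurred, and the two are blended as in `sag_masking`. The sketch below is a rough, standalone illustration of that blend; the tensor sizes (4-channel 64×64 latent, 8×8 mid-block token grid) and the average-pool stand-in for `gaussian_blur_2d` are assumptions, not values taken from this upload:

```python
import torch
import torch.nn.functional as F

# Illustrative SAG-style blend (assumed 4-channel 64x64 latent, 8x8 = 64-token attention grid).
b, c, h, w = 1, 4, 64, 64
pred_x0 = torch.randn(b, c, h, w)
attn_map = torch.rand(b, 4, 64, 64)  # (batch, heads, tokens, tokens)

# Threshold the head-averaged attention into a binary mask, as in sag_masking (sum over keys > 1.0).
attn_mask = (attn_map.mean(1).sum(1) > 1.0).reshape(b, 8, 8).unsqueeze(1).repeat(1, c, 1, 1).float()
attn_mask = F.interpolate(attn_mask, (h, w))

# Stand-in for gaussian_blur_2d(pred_x0, kernel_size=9, sigma=1.0); any smoothing works for the sketch.
blurred = F.avg_pool2d(F.pad(pred_x0, (1, 1, 1, 1), mode="reflect"), kernel_size=3, stride=1)
degraded = blurred * attn_mask + pred_x0 * (1 - attn_mask)
print(degraded.shape)  # torch.Size([1, 4, 64, 64])
```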
cldm/ddim_hacked_sag.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+ import torch.nn.functional as F
9
+
10
+ import cv2
11
+ # Gaussian blur
12
+ def gaussian_blur_2d(img, kernel_size, sigma):
13
+ ksize_half = (kernel_size - 1) * 0.5
14
+
15
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
16
+
17
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
18
+
19
+ x_kernel = pdf / pdf.sum()
20
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
21
+
22
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
23
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
24
+
25
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
26
+
27
+ img = F.pad(img, padding, mode="reflect")
28
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
29
+
30
+ return img
31
+
32
+ # processes and stores attention probabilities
33
+ class CrossAttnStoreProcessor:
34
+ def __init__(self):
35
+ self.attention_probs = None
36
+
37
+ def __call__(
38
+ self,
39
+ attn,
40
+ hidden_states,
41
+ encoder_hidden_states=None,
42
+ attention_mask=None,
43
+ ):
44
+ batch_size, sequence_length, _ = hidden_states.shape
45
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(self.attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ return hidden_states
70
+
71
+ class DDIMSampler(object):
72
+ def __init__(self, model, schedule="linear", **kwargs):
73
+ super().__init__()
74
+ self.model = model
75
+ self.ddpm_num_timesteps = model.num_timesteps
76
+ self.schedule = schedule
77
+
78
+ def register_buffer(self, name, attr):
79
+ if type(attr) == torch.Tensor:
80
+ if attr.device != torch.device("cuda"):
81
+ attr = attr.to(torch.device("cuda"))
82
+ setattr(self, name, attr)
83
+
84
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
85
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
86
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
87
+ alphas_cumprod = self.model.alphas_cumprod
88
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
89
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
90
+
91
+ self.register_buffer('betas', to_torch(self.model.betas))
92
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
93
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
94
+
95
+ # calculations for diffusion q(x_t | x_{t-1}) and others
96
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
97
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
98
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
99
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
100
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
101
+
102
+ # ddim sampling parameters
103
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
104
+ ddim_timesteps=self.ddim_timesteps,
105
+ eta=ddim_eta,verbose=verbose)
106
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
107
+ self.register_buffer('ddim_alphas', ddim_alphas)
108
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
109
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
110
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
111
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
112
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
113
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
114
+
115
+ @torch.no_grad()
116
+ def sample(self,
117
+ S,
118
+ batch_size,
119
+ shape,
120
+ conditioning=None,
121
+ callback=None,
122
+ normals_sequence=None,
123
+ img_callback=None,
124
+ quantize_x0=False,
125
+ eta=0.,
126
+ mask=None,
127
+ masked_image_latents=None,
128
+ x0=None,
129
+ temperature=1.,
130
+ noise_dropout=0.,
131
+ score_corrector=None,
132
+ corrector_kwargs=None,
133
+ verbose=True,
134
+ x_T=None,
135
+ log_every_t=100,
136
+ unconditional_guidance_scale=1.,
137
+ sag_scale=0.75,
138
+ SAG_influence_step=600,
139
+ noise = None,
140
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
141
+ dynamic_threshold=None,
142
+ ucg_schedule=None,
143
+ **kwargs
144
+ ):
145
+ if conditioning is not None:
146
+ if isinstance(conditioning, dict):
147
+ ctmp = conditioning[list(conditioning.keys())[0]]
148
+ while isinstance(ctmp, list): ctmp = ctmp[0]
149
+ cbs = ctmp.shape[0]
150
+ if cbs != batch_size:
151
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
152
+
153
+ elif isinstance(conditioning, list):
154
+ for ctmp in conditioning:
155
+ if ctmp.shape[0] != batch_size:
156
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
157
+
158
+ else:
159
+ if conditioning.shape[0] != batch_size:
160
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
161
+
162
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
163
+ # sampling
164
+ C, H, W = shape
165
+ size = (batch_size, C, H, W)
166
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
167
+
168
+ samples, intermediates = self.ddim_sampling(conditioning, size,
169
+ callback=callback,
170
+ img_callback=img_callback,
171
+ quantize_denoised=quantize_x0,
172
+ mask=mask,masked_image_latents=masked_image_latents, x0=x0,
173
+ ddim_use_original_steps=False,
174
+ noise_dropout=noise_dropout,
175
+ temperature=temperature,
176
+ score_corrector=score_corrector,
177
+ corrector_kwargs=corrector_kwargs,
178
+ x_T=x_T,
179
+ log_every_t=log_every_t,
180
+ unconditional_guidance_scale=unconditional_guidance_scale,
181
+ sag_scale = sag_scale,
182
+ SAG_influence_step = SAG_influence_step,
183
+ noise = noise,
184
+ unconditional_conditioning=unconditional_conditioning,
185
+ dynamic_threshold=dynamic_threshold,
186
+ ucg_schedule=ucg_schedule
187
+ )
188
+ return samples, intermediates
189
+
190
+ def add_noise(self,
191
+ original_samples: torch.FloatTensor,
192
+ noise: torch.FloatTensor,
193
+ timesteps: torch.IntTensor,
194
+ ) -> torch.FloatTensor:
195
+ betas = torch.linspace(0.00085, 0.0120, 1000, dtype=torch.float32)
196
+ alphas = 1.0 - betas
197
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
198
+ alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
199
+ timesteps = timesteps.to(original_samples.device)
200
+
201
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
202
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
203
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
204
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
205
+
206
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
207
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
208
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
209
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
210
+
211
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
212
+
213
+ return noisy_samples
214
+ # def add_noise(
+ # self,
+ # original_samples: torch.FloatTensor,
+ # noise: torch.FloatTensor,
+ # timesteps: torch.FloatTensor,
+ # sigma_t,
+ # ) -> torch.FloatTensor:
+
+ # # Make sure sigmas and timesteps have the same device and dtype as original_samples
+
+ # sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ # if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # # mps does not support float64
+ # schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ # timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ # else:
+ # schedule_timesteps = self.timesteps.to(original_samples.device)
+ # timesteps = timesteps.to(original_samples.device)
+
+ # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ # sigma = sigmas[step_indices].flatten()
+ # while len(sigma.shape) < len(original_samples.shape):
+ # sigma = sigma.unsqueeze(-1)
+ # # print(sigma_t)
+ # noisy_samples = original_samples + noise * sigma_t
+ # return noisy_samples
+
+
+ def sag_masking(self, original_latents,model_output,x, attn_map, map_size, t, eps):
+ # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+ bh, hw1, hw2 = attn_map.shape
+ b, latent_channel, latent_h, latent_w = original_latents.shape
+ h = 4 #self.unet.config.attention_head_dim
+ if isinstance(h, list):
+ h = h[-1]
+ # print(attn_map.shape)
+ # print(original_latents.shape)
+ # print(map_size)
+ # Produce attention mask
+ attn_map = attn_map.reshape(b, h, hw1, hw2)
+ attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+ # print(attn_mask.shape)
+ attn_mask = (
+ attn_mask.reshape(b, map_size[0], map_size[1])
+ .unsqueeze(1)
+ .repeat(1, latent_channel, 1, 1)
+ .type(attn_map.dtype)
+ )
+ attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
+ # print(attn_mask.shape)
+ # cv2.imwrite("attn_mask.png",attn_mask)
+ # Blur according to the self-attention mask
+ degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
+ # degraded_latents = self.add_noise(degraded_latents, noise=eps, timesteps=t)#,sigma_t=sigma_t)
+ degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) #x#original_latents
+ # degraded_latents = self.model.get_x_t_from_start_and_t(degraded_latents,t,model_output)
+ # print(original_latents.shape)
+ # print(eps.shape)
+ # Noise it again to match the noise level
+ # print("t",t)
+ # degraded_latents = self.add_noise(degraded_latents, noise=eps, timesteps=t)#,sigma_t=sigma_t)
+
+ return degraded_latents
+
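`sag_masking` thresholds the averaged self-attention map, upsamples it to the latent resolution, and blurs only the attended regions. The `gaussian_blur_2d` helper it calls is not shown in this diff; in the reference SAG implementation (e.g. diffusers' StableDiffusionSAGPipeline) it looks roughly like the sketch below, which is included here only for context:

import torch
import torch.nn.functional as F

def gaussian_blur_2d(img, kernel_size, sigma):
    # Separable Gaussian blur with reflect padding (reference SAG helper, not from this diff).
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()
    kernel2d = torch.mm(kernel1d[:, None], kernel1d[None, :]).to(img.device, dtype=img.dtype)
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
    padding = [kernel_size // 2] * 4
    img = F.pad(img, padding, mode="reflect")
    return F.conv2d(img, kernel2d, groups=img.shape[-3])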
+ def pred_epsilon(self, sample, model_output, timestep):
+ alpha_prod_t = timestep
+
+ beta_prod_t = 1 - alpha_prod_t
+ # print(self.model.parameterization)#eps
+ if self.model.parameterization == "eps":
+ pred_eps = model_output
+ elif self.model.parameterization == "sample":
+ pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+ elif self.model.parameterization == "v":
+ pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `eps`, `sample`,"
+ " or `v`"
+ )
+
+ return pred_eps
+
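`pred_epsilon` expects the cumulative alpha of the current step in place of a raw timestep and converts the network output into an epsilon prediction for the `eps`, `sample` (x0) and `v` parameterizations. A small self-contained numerical check of the v-parameterization identity it relies on (illustrative only):

import torch

a = torch.tensor(0.7)                     # stands in for alpha_prod_t
x0 = torch.randn(4)
eps = torch.randn(4)
x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps
v = a.sqrt() * eps - (1 - a).sqrt() * x0  # definition of the v target
# eps recovered exactly as in the "v" branch above:
assert torch.allclose((1 - a).sqrt() * x_t + a.sqrt() * v, eps, atol=1e-6)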
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None,masked_image_latents=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1.,sag_scale = 0.75, SAG_influence_step=600, sag_enable = True, noise = None, unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+ # timesteps =100
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+ # timesteps=timesteps[:-3]
+ # print("timesteps",timesteps)
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ print(step)
+ if step > SAG_influence_step:
+ sag_enable_t=True
+ else:
+ sag_enable_t=False
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ # if mask is not None:
+ # assert x0 is not None
+ # img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ # img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img,mask,masked_image_latents, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ sag_scale = sag_scale,
+ sag_enable=sag_enable_t,
+ noise =noise,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
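Because `sag_enable_t` is derived from the raw DDPM timestep, self-attention guidance is only applied while `step > SAG_influence_step`, i.e. during the early, high-noise portion of the reverse trajectory. A rough illustration with a uniform 50-step grid (the actual spacing is whatever `make_schedule` produced):

import numpy as np

ddim_timesteps = np.arange(1, 1000, 20)              # ~50 uniformly spaced DDPM timesteps
sag_active = [int(t) for t in np.flip(ddim_timesteps) if t > 600]
print(f"SAG applied on {len(sag_active)} of {len(ddim_timesteps)} steps")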
+ @torch.no_grad()
+ def p_sample_ddim(self, x,mask,masked_image_latents, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1.,sag_scale = 0.75, sag_enable=True, noise=None, unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ # map_size = None
+ # def get_map_size(module, input, output):
+ # nonlocal map_size
+ # map_size = output.shape[-2:]
+
+ # store_processor = CrossAttnStoreProcessor()
+ # for name, param in self.model.model.diffusion_model.named_parameters():
+ # print(name)
+ # self.model.control_model.middle_block[1].transformer_blocks[0].attn1.processor = store_processor
+ # print(self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1)
+ # self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1 = store_processor
+
+ # with self.model.model.diffusion_model.middle_block[1].register_forward_hook(get_map_size):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x,mask,masked_image_latents, t, c)
+ else:
+ model_t = self.model.apply_model(x,mask,masked_image_latents, t, c)
+ model_uncond = self.model.apply_model(x,mask,masked_image_latents, t, unconditional_conditioning)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+ if sag_enable == True:
+ uncond_attn, cond_attn = self.model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn1.attention_probs.chunk(2)
+ # self-attention-based degrading of latents
+ map_size = self.model.model.diffusion_model.middle_block[1].map_size
+ degraded_latents = self.sag_masking(
+ pred_x0,model_output,x,uncond_attn, map_size, t, eps = noise, #self.pred_epsilon(x, model_uncond, self.model.alphas_cumprod[t]),#noise
+ )
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ degraded_model_output = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c)
+ else:
+ degraded_model_t = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, c)
+ degraded_model_uncond = self.model.apply_model(degraded_latents,mask,masked_image_latents, t, unconditional_conditioning)
+ degraded_model_output = degraded_model_uncond + unconditional_guidance_scale * (degraded_model_t - degraded_model_uncond)
+ # print("sag_scale",sag_scale)
+ model_output += sag_scale * (model_output - degraded_model_output)
+ # model_output = (1-sag_scale) * model_output + sag_scale * degraded_model_output
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
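The tail of `p_sample_ddim` is the standard DDIM update: classifier-free guidance mixes the conditional and unconditional predictions, SAG then pushes the result away from the prediction made on the blurred latents, and the actual step reduces to the form below (with eta = 0, i.e. sigma_t = 0, the step is deterministic). A compact restatement under those definitions, for reference only:

import torch

def ddim_step(x_t, eps, a_t, a_prev, sigma_t):
    # a_t, a_prev, sigma_t: tensors broadcastable against x_t (e.g. shape (b, 1, 1, 1))
    pred_x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()          # current x_0 estimate
    dir_xt = (1 - a_prev - sigma_t ** 2).sqrt() * eps              # direction pointing to x_t
    return a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)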
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ num_reference_steps = timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
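`stochastic_encode` simply noises `x0` to step `t` with the precomputed coefficients. The `extract_into_tensor` helper it uses comes from ldm.modules.diffusionmodules.util and, for reference, looks roughly like this:

def extract_into_tensor(a, t, x_shape):
    # Gather per-sample coefficients at indices t and reshape for broadcasting
    # against a tensor of shape x_shape (ldm utility, shown here only for context).
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))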
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
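Together, `stochastic_encode` and `decode` give the usual img2img path: noise a latent part-way into the schedule, then denoise it back under new conditioning. Note that `decode` (and `encode` above) are carried over from upstream ldm and still call `p_sample_ddim`/`apply_model` without the `mask`/`masked_image_latents` arguments added elsewhere in this file, so they would need the same adaptation before use. A hypothetical roundtrip, with `sampler`, `z`, `cond` and `uncond` supplied by the caller:

import torch

t_enc = 25                                            # how many DDIM steps of noise to add
z_noisy = sampler.stochastic_encode(z, torch.tensor([t_enc] * z.shape[0], device=z.device))
z_dec = sampler.decode(z_noisy, cond, t_enc,
                       unconditional_guidance_scale=9.0,
                       unconditional_conditioning=uncond)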
cldm/hack.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ import einops
+
+ import ldm.modules.encoders.modules
+ import ldm.modules.attention
+
+ from transformers import logging
+ from ldm.modules.attention import default
+
+
+ def disable_verbosity():
+ logging.set_verbosity_error()
+ print('logging improved.')
+ return
+
+
+ def enable_sliced_attention():
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
+ print('Enabled sliced_attention.')
+ return
+
+
+ def hack_everything(clip_skip=0):
+ disable_verbosity()
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+ print('Enabled clip hacks.')
+ return
+
+
+ # Written by Lvmin
+ def _hacked_clip_forward(self, text):
+ PAD = self.tokenizer.pad_token_id
+ EOS = self.tokenizer.eos_token_id
+ BOS = self.tokenizer.bos_token_id
+
+ def tokenize(t):
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+ def transformer_encode(t):
+ if self.clip_skip > 1:
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+ else:
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+ def split(x):
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+ def pad(x, p, i):
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+ raw_tokens_list = tokenize(text)
+ tokens_list = []
+
+ for raw_tokens in raw_tokens_list:
+ raw_tokens_123 = split(raw_tokens)
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+ tokens_list.append(raw_tokens_123)
+
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+ y = transformer_encode(feed)
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+ return z
+
+
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ limit = k.shape[0]
+ att_step = 1
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+ q_chunks.reverse()
+ k_chunks.reverse()
+ v_chunks.reverse()
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ del k, q, v
+ for i in range(0, limit, att_step):
+ q_buffer = q_chunks.pop()
+ k_buffer = k_chunks.pop()
+ v_buffer = v_chunks.pop()
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+ del k_buffer, q_buffer
+ # attention, what we cannot get enough of, by chunks
+
+ sim_buffer = sim_buffer.softmax(dim=-1)
+
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+ del v_buffer
+ sim[i:i + att_step, :, :] = sim_buffer
+
+ del sim_buffer
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
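As in upstream ControlNet, these monkey patches are opt-in: `disable_verbosity` silences transformers warnings, `hack_everything` swaps in the chunked CLIP forward above (three 75-token windows, so prompts up to roughly 225 tokens are encoded), and `enable_sliced_attention` replaces `CrossAttention.forward` with the lower-memory sliced variant. A typical invocation from a script in this repo might be:

from cldm.hack import disable_verbosity, enable_sliced_attention

disable_verbosity()
enable_sliced_attention()    # optional: trade speed for lower attention memory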
cldm/model.py ADDED
@@ -0,0 +1,28 @@
+ import os
+ import torch
+
+ from omegaconf import OmegaConf
+ from ldm.util import instantiate_from_config
+
+
+ def get_state_dict(d):
+ return d.get('state_dict', d)
+
+
+ def load_state_dict(ckpt_path, location='cpu'):
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+ state_dict = get_state_dict(state_dict)
+ print(f'Loaded state_dict from [{ckpt_path}]')
+ return state_dict
+
+
+ def create_model(config_path):
+ config = OmegaConf.load(config_path)
+ model = instantiate_from_config(config.model).cpu()
+ print(f'Loaded model config from [{config_path}]')
+ return model
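A typical loading sequence built from these helpers, in the style of upstream ControlNet scripts (the config and checkpoint paths below are placeholders, not files guaranteed to exist in this repo):

from cldm.model import create_model, load_state_dict

model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_ini.ckpt', location='cpu'))
model = model.cuda()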