helblazer811 committed
Commit 29c7873 · 1 Parent(s): bde9560

Added capability for uploading existing images to the UI.
app.py CHANGED
@@ -45,6 +45,12 @@ def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timeste
     if len(concepts) > 9:
         raise gr.Error("Please enter at most 9 concepts", duration=10)
 
+    print(f"Num samples: {num_samples}")
+    print(f"Layer start index: {layer_start_index}")
+    print(f"Noise timestep: {noise_timestep}")
+    print(image)
+    image = image.convert("RGB")
+
     pipeline_output = pipeline.encode_image(
         image=image,
         prompt=prompt,
@@ -318,7 +324,7 @@ with gr.Blocks(
         with gr.Column(scale=1, min_width=250):
             generated_image = gr.Image(
                 elem_classes="generated-image",
-                show_label=False
+                show_label=False,
             )
 
         with gr.Column(scale=4):
@@ -419,7 +425,9 @@ with gr.Blocks(
             input_image = gr.Image(
                 elem_classes="generated-image",
                 show_label=False,
-                interactive=True
+                interactive=True,
+                type="pil",
+                image_mode="RGB",
             )
 
         with gr.Column(scale=4):
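
For context, a minimal sketch (not part of the commit; the handler and layout names here are hypothetical) of how a gr.Image configured this way hands the callback a ready-to-use RGB PIL image:

import gradio as gr
from PIL import Image

def encode_uploaded_image(image: Image.Image, prompt: str):
    # Placeholder handler; the real app calls pipeline.encode_image(...) here.
    return image

with gr.Blocks() as demo:
    input_image = gr.Image(
        show_label=False,
        interactive=True,   # allow the user to upload / drag-and-drop an image
        type="pil",         # deliver a PIL.Image.Image to the callback
        image_mode="RGB",   # strip alpha so downstream VAE encoding sees 3 channels
    )
    prompt = gr.Textbox(label="Prompt")
    output_image = gr.Image(show_label=False)
    run = gr.Button("Encode")
    run.click(encode_uploaded_image, inputs=[input_image, prompt], outputs=output_image)

# demo.launch()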
concept_attention/concept_attention_pipeline.py CHANGED
@@ -5,8 +5,12 @@ from dataclasses import dataclass
 import PIL
 import numpy as np
 import matplotlib.pyplot as plt
+from concept_attention.flux.src.flux.sampling import prepare
+from concept_attention.segmentation import add_noise_to_image, encode_image
+from concept_attention.utils import embed_concepts, linear_normalization
 import torch
 import einops
+from tqdm import tqdm
 
 from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
 from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
@@ -18,6 +22,65 @@ class ConceptAttentionPipelineOutput():
     concept_heatmaps: list[PIL.Image.Image]
     cross_attention_maps: list[PIL.Image.Image]
 
+
+def compute_heatmaps_from_vectors(
+    image_vectors,
+    concept_vectors,
+    layer_indices: list[int],
+    timesteps: list[int] = list(range(4)),
+    softmax: bool = True,
+    normalize_concepts: bool = True
+):
+    """
+    Accepts image vectors and concept vectors. These can be from cross attentions or attention outputs.
+    """
+    print(f"Image vectors shape: {image_vectors.shape}")
+    print(f"Concept vectors shape: {concept_vectors.shape}")
+    # Check if there are heads in the input
+    if len(image_vectors.shape) == 6:
+        # Collapse the had dimension
+        image_vectors = einops.rearrange(
+            image_vectors,
+            "time layers batch head patches dim -> time layers batch patches (head dim)"
+        )
+        concept_vectors = einops.rearrange(
+            concept_vectors,
+            "time layers batch head concepts dim -> time layers batch concepts (head dim)"
+        )
+
+
+    # Apply linear normalization to concepts
+    if normalize_concepts:
+        concept_vectors = linear_normalization(concept_vectors, dim=-2)
+
+    # Compute heatmaps
+    heatmaps = einops.einsum(
+        image_vectors,
+        concept_vectors,
+        "time layers batch patches dim, time layers batch concepts dim -> time layers batch concepts patches",
+    )
+
+    # Apply softmax
+    if softmax:
+        heatmaps = torch.nn.functional.softmax(heatmaps, dim=-2)
+    # Pull out the timesteps and layers
+    heatmaps = heatmaps[timesteps]
+    heatmaps = heatmaps[:, layer_indices]
+    # Average over the heatmaps
+    heatmaps = einops.reduce(
+        heatmaps,
+        "time layers batch concepts patches -> batch concepts patches",
+        reduction="mean"
+    )
+    heatmaps = einops.rearrange(
+        heatmaps,
+        "batch concepts (h w) -> batch concepts h w",
+        h=64,
+        w=64
+    )
+
+    return heatmaps
+
 class ConceptAttentionFluxPipeline():
     """
     This is an object that allows you to generate images with flux, and
@@ -66,7 +129,7 @@ class ConceptAttentionFluxPipeline():
         if timesteps is None:
             timesteps = list(range(num_inference_steps))
         # Run the raw output space object
-        image, cross_attention_maps, concept_heatmaps = self.flux_generator.generate_image(
+        image, concept_attention_dict = self.flux_generator.generate_image(
             width=width,
             height=height,
             prompt=prompt,
@@ -75,56 +138,43 @@ class ConceptAttentionFluxPipeline():
             seed=seed,
             guidance=guidance,
         )
-        # Concept heamaps extraction
-        if softmax:
-            concept_heatmaps = torch.nn.functional.softmax(concept_heatmaps, dim=-2)
-
-        concept_heatmaps = concept_heatmaps[:, layer_indices]
-        concept_heatmaps = einops.reduce(
-            concept_heatmaps,
-            "time layers batch concepts patches -> batch concepts patches",
-            reduction="mean"
-        )
-        concept_heatmaps = einops.rearrange(
-            concept_heatmaps,
-            "batch concepts (h w) -> batch concepts h w",
-            h=64,
-            w=64
-        )
-        # Cross attention maps
-        if softmax:
-            cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
-
-        cross_attention_maps = cross_attention_maps[:, layer_indices]
-        cross_attention_maps = einops.reduce(
-            cross_attention_maps,
-            "time layers batch concepts patches -> batch concepts patches",
-            reduction="mean"
+
+        cross_attention_maps = compute_heatmaps_from_vectors(
+            concept_attention_dict["cross_attention_image_vectors"],
+            concept_attention_dict["cross_attention_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
         )
-        cross_attention_maps = einops.rearrange(
-            cross_attention_maps,
-            "batch concepts (h w) -> batch concepts h w",
-            h=64,
-            w=64
+        concept_heatmaps = compute_heatmaps_from_vectors(
+            concept_attention_dict["output_space_image_vectors"],
+            concept_attention_dict["output_space_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
         )
-
+
         concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
         cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
         # Convert the torch heatmaps to PIL images.
         if return_pil_heatmaps:
+            concept_heatmaps_min = concept_heatmaps.min()
+            concept_heatmaps_max = concept_heatmaps.max()
             # Convert to a matplotlib color scheme
             colored_heatmaps = []
             for concept_heatmap in concept_heatmaps:
-                concept_heatmap = (concept_heatmap - concept_heatmap.min()) / (concept_heatmap.max() - concept_heatmap.min())
+                concept_heatmap = (concept_heatmap - concept_heatmaps_min) / (concept_heatmaps_max - concept_heatmaps_min)
                 colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
                 rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
                 colored_heatmaps.append(rgb_image)
 
             concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
 
+            cross_attention_min = cross_attention_maps.min()
+            cross_attention_max = cross_attention_maps.max()
             colored_cross_attention_maps = []
             for cross_attention_map in cross_attention_maps:
-                cross_attention_map = (cross_attention_map - cross_attention_map.min()) / (cross_attention_map.max() - cross_attention_map.min())
+                cross_attention_map = (cross_attention_map - cross_attention_min) / (cross_attention_max - cross_attention_min)
                 colored_cross_attention_map = plt.get_cmap(cmap)(cross_attention_map)
                 rgb_image = (colored_cross_attention_map[:, :, :3] * 255).astype(np.uint8)
                 colored_cross_attention_maps.append(rgb_image)
@@ -137,58 +187,203 @@ class ConceptAttentionFluxPipeline():
             cross_attention_maps=cross_attention_maps
         )
 
-    # def encode_image(
-    #     self,
-    #     image: PIL.Image.Image,
-    #     concepts: list[str],
-    #     prompt: str = "", # Optional
-    #     width: int = 1024,
-    #     height: int = 1024,
-    #     return_cross_attention = False,
-    #     layer_indices = list(range(15, 19)),
-    #     num_samples: int = 1,
-    #     device: str = "cuda:0",
-    #     return_pil_heatmaps: bool = True,
-    #     seed: int = 0,
-    #     cmap="plasma"
-    # ) -> ConceptAttentionPipelineOutput:
-    #     """
-    #     Encode an image with flux, given a list of concepts.
-    #     """
-    #     assert return_cross_attention is False, "Not supported yet"
-    #     assert all([layer_index >= 0 and layer_index < 19 for layer_index in layer_indices]), "Invalid layer index"
-    #     assert height == width, "Height and width must be the same for now"
-    #     # Run the raw output space object
-    #     concept_heatmaps, _ = self.output_space_segmentation_model.segment_individual_image(
-    #         image=image,
-    #         concepts=concepts,
-    #         caption=prompt,
-    #         device=device,
-    #         softmax=True,
-    #         layers=layer_indices,
-    #         num_samples=num_samples,
-    #         height=height,
-    #         width=width
-    #     )
-    #     concept_heatmaps = concept_heatmaps.detach().cpu().numpy().squeeze()
-
-    #     # Convert the torch heatmaps to PIL images.
-    #     if return_pil_heatmaps:
-    #         min_val = concept_heatmaps.min()
-    #         max_val = concept_heatmaps.max()
-    #         # Convert to a matplotlib color scheme
-    #         colored_heatmaps = []
-    #         for concept_heatmap in concept_heatmaps:
-    #             # concept_heatmap = (concept_heatmap - concept_heatmap.min()) / (concept_heatmap.max() - concept_heatmap.min())
-    #             concept_heatmap = (concept_heatmap - min_val) / (max_val - min_val)
-    #             colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
-    #             rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
-    #             colored_heatmaps.append(rgb_image)
-
-    #         concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
-
-    #         return ConceptAttentionPipelineOutput(
-    #             image=image,
-    #             concept_heatmaps=concept_heatmaps
-    #         )
+    def encode_image(
+        self,
+        image: PIL.Image.Image,
+        concepts: list[str],
+        prompt: str = "", # Optional
+        width: int = 1024,
+        height: int = 1024,
+        layer_indices = list(range(15, 19)),
+        num_samples: int = 1,
+        num_steps: int = 4,
+        noise_timestep: int = 2,
+        device: str = "cuda:0",
+        return_pil_heatmaps: bool = True,
+        seed: int = 0,
+        cmap="plasma",
+        stop_after_multi_modal_attentions=True,
+        softmax=True
+    ) -> ConceptAttentionPipelineOutput:
+        """
+        Encode an image with flux, given a list of concepts.
+        """
+        assert all([layer_index >= 0 and layer_index < 19 for layer_index in layer_indices]), "Invalid layer index"
+        assert height == width, "Height and width must be the same for now"
+        print("Encoding image")
+
+        # Encode the image into the VAE latent space
+        encoded_image_without_noise = encode_image(
+            image,
+            self.flux_generator.ae,
+            offload=self.flux_generator.offload,
+            device=device,
+        )
+        # Do N trials
+        combined_concept_attention_dict = {
+            "cross_attention_image_vectors": [],
+            "cross_attention_concept_vectors": [],
+            "output_space_image_vectors": [],
+            "output_space_concept_vectors": []
+        }
+        print("Sampling")
+        for i in tqdm(range(num_samples)):
+            # Add noise to image
+            encoded_image, timesteps = add_noise_to_image(
+                encoded_image_without_noise,
+                num_steps=num_steps,
+                noise_timestep=noise_timestep,
+                seed=seed + i,
+                width=width,
+                height=height,
+                device=device,
+                is_schnell=self.flux_generator.is_schnell,
+            )
+            # Now run the diffusion model once on the noisy image
+            # Encode the concept vectors
+
+            if self.flux_generator.offload:
+                self.flux_generator.t5, self.flux_generator.clip = self.flux_generator.t5.to(device), self.flux_generator.clip.to(device)
+            inp = prepare(t5=self.flux_generator.t5, clip=self.flux_generator.clip, img=encoded_image, prompt=prompt)
+
+            concept_embeddings, concept_ids, concept_vec = embed_concepts(
+                self.flux_generator.clip,
+                self.flux_generator.t5,
+                concepts,
+            )
+
+            inp["concepts"] = concept_embeddings.to(encoded_image.device)
+            inp["concept_ids"] = concept_ids.to(encoded_image.device)
+            inp["concept_vec"] = concept_vec.to(encoded_image.device)
+            # offload TEs to CPU, load model to gpu
+            if self.flux_generator.offload:
+                self.flux_generator.t5, self.flux_generator.clip = self.flux_generator.t5.cpu(), self.flux_generator.clip.cpu()
+                torch.cuda.empty_cache()
+                self.flux_generator.model = self.flux_generator.model.to(device)
+            # Denoise the intermediate images
+            guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype)
+            t_curr = timesteps[0]
+            t_prev = timesteps[1]
+            t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device)
+            _, concept_attention_dict = self.flux_generator.model(
+                img=inp["img"],
+                img_ids=inp["img_ids"],
+                txt=inp["txt"],
+                txt_ids=inp["txt_ids"],
+                concepts=inp["concepts"],
+                concept_ids=inp["concept_ids"],
+                concept_vec=inp["concept_vec"],
+                y=inp["concept_vec"],
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                stop_after_multimodal_attentions=stop_after_multi_modal_attentions, # Always true for the demo
+                joint_attention_kwargs=None,
+            )
+
+            for key in combined_concept_attention_dict.keys():
+                combined_concept_attention_dict[key].append(concept_attention_dict[key])
+
+        # all_concept_heatmaps.append(concept_heatmaps)
+        # all_cross_attention_maps.append(cross_attention_maps)
+
+        # Pull out the concept and image vectors from each block
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key]).squeeze(1)
+
+        # Compute the heatmaps
+        concept_heatmaps = compute_heatmaps_from_vectors(
+            combined_concept_attention_dict["output_space_image_vectors"],
+            combined_concept_attention_dict["output_space_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+
+        cross_attention_maps = compute_heatmaps_from_vectors(
+            combined_concept_attention_dict["cross_attention_image_vectors"],
+            combined_concept_attention_dict["cross_attention_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+
+        # # Pull out the concept and image vectors from each block
+        # image_vectors = torch.stack(self.flux_generator.model.image_vectors).squeeze(1)
+        # concept_vectors = torch.stack(self.flux_generator.model.concept_vectors).squeeze(1)
+        # # Apply linear normalization ot the concept vectors
+        # if True:
+        #     concept_vectors = linear_normalization(concept_vectors, dim=-2)
+        # # Compute the heatmaps
+        # concept_heatmaps = einops.einsum(
+        #     image_vectors,
+        #     concept_vectors,
+        #     "time layers batch patches dim, time layers batch concepts dim -> time layers batch concepts patches",
+        # )
+        # concept_heatmaps = torch.stack(all_concept_heatmaps, dim=0)
+        # cross_attention_maps = torch.stack(all_cross_attention_maps, dim=0)
+        # Concept heamaps extraction
+        # if softmax:
+        #     concept_heatmaps = torch.nn.functional.softmax(concept_heatmaps, dim=-2)
+
+        # concept_heatmaps = concept_heatmaps[:, layer_indices]
+        # concept_heatmaps = einops.reduce(
+        #     concept_heatmaps,
+        #     "time layers batch concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # concept_heatmaps = einops.rearrange(
+        #     concept_heatmaps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
+        # Cross attention maps
+        # if softmax:
+        #     cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
+
+        # cross_attention_maps = cross_attention_maps[:, layer_indices]
+        # cross_attention_maps = einops.reduce(
+        #     cross_attention_maps,
+        #     "time layers batch concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # cross_attention_maps = einops.rearrange(
+        #     cross_attention_maps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
+        concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
+        # cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
+        cross_attention_maps = concept_heatmaps
+        # Convert the torch heatmaps to PIL images.
+        if return_pil_heatmaps:
+            concept_heatmaps_min = concept_heatmaps.min()
+            concept_heatmaps_max = concept_heatmaps.max()
+            # Convert to a matplotlib color scheme
+            colored_heatmaps = []
+            for concept_heatmap in concept_heatmaps:
+                concept_heatmap = (concept_heatmap - concept_heatmaps_min) / (concept_heatmaps_max - concept_heatmaps_min)
+                colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
+                rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
+                colored_heatmaps.append(rgb_image)
+
+            concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
+
+            cross_attention_min = cross_attention_maps.min()
+            cross_attention_max = cross_attention_maps.max()
+            colored_cross_attention_maps = []
+            for cross_attention_map in cross_attention_maps:
+                cross_attention_map = (cross_attention_map - cross_attention_min) / (cross_attention_max - cross_attention_min)
+                colored_cross_attention_map = plt.get_cmap(cmap)(cross_attention_map)
+                rgb_image = (colored_cross_attention_map[:, :, :3] * 255).astype(np.uint8)
+                colored_cross_attention_maps.append(rgb_image)
+
+            cross_attention_maps = [PIL.Image.fromarray(cross_attention_map) for cross_attention_map in colored_cross_attention_maps]
+
+        return ConceptAttentionPipelineOutput(
+            image=image,
+            concept_heatmaps=concept_heatmaps,
+            cross_attention_maps=cross_attention_maps
+        )
 
389
 
concept_attention/flux/src/flux/sampling.py CHANGED
@@ -111,14 +111,18 @@ def denoise(
     joint_attention_kwargs=None,
 ):
     intermediate_images = [img]
-    all_cross_attention_maps = []
-    all_concept_attention_maps = []
+    combined_concept_attention_dict = {
+        "output_space_concept_vectors": [],
+        "output_space_image_vectors": [],
+        "cross_attention_concept_vectors": [],
+        "cross_attention_image_vectors": [],
+    }
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     iteration = 0
     for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        pred, cross_attention_maps, concept_attention_maps = model(
+        pred, concept_attention_dict = model(
             img=img,
             img_ids=img_ids,
             txt=txt,
@@ -138,13 +142,13 @@ def denoise(
         # increment iteration
         iteration += 1
 
-        all_cross_attention_maps.append(cross_attention_maps)
-        all_concept_attention_maps.append(concept_attention_maps)
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key].append(concept_attention_dict[key])
 
-    all_cross_attention_maps = torch.stack(all_cross_attention_maps, dim=0)
-    all_concept_attention_maps = torch.stack(all_concept_attention_maps, dim=0)
+    for key in combined_concept_attention_dict.keys():
+        combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key], dim=0)
 
-    return img, intermediate_images, all_cross_attention_maps, all_concept_attention_maps
+    return img, intermediate_images, combined_concept_attention_dict
 
 def unpack(x: Tensor, height: int, width: int) -> Tensor:
     return rearrange(
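
The same accumulate-then-stack pattern now appears in denoise(), ModifiedFluxDiT, and the new encode_image(). A self-contained sketch of that pattern (random tensors with toy shapes stand in for the per-step attention outputs):

import torch

# One dict of lists, appended to on every denoising step and stacked at the end,
# replaces the separate all_cross_attention_maps / all_concept_attention_maps lists.
combined = {
    "output_space_concept_vectors": [],
    "output_space_image_vectors": [],
    "cross_attention_concept_vectors": [],
    "cross_attention_image_vectors": [],
}

num_steps = 3
for _ in range(num_steps):
    # Stand-in for the per-step concept_attention_dict returned by the model
    # (toy shape: layers x batch x heads x tokens x head_dim).
    step_dict = {key: torch.randn(19, 1, 2, 64, 16) for key in combined}
    for key in combined:
        combined[key].append(step_dict[key])

# Stack along a leading "time" axis, matching the new denoise() return value.
for key in combined:
    combined[key] = torch.stack(combined[key], dim=0)

print(combined["output_space_image_vectors"].shape)
# torch.Size([3, 19, 1, 2, 64, 16])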
concept_attention/image_generator.py CHANGED
@@ -171,7 +171,7 @@ class FluxGenerator():
             torch.cuda.empty_cache()
             self.model = self.model.to(self.device)
         # denoise initial noise
-        x, intermediate_images, cross_attention_maps, concept_attention_maps = denoise(
+        x, _, concept_attention_dict = denoise(
             self.model,
             **inp,
             timesteps=timesteps,
@@ -203,4 +203,4 @@ class FluxGenerator():
 
         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
 
-        return img, cross_attention_maps, concept_attention_maps
+        return img, concept_attention_dict
concept_attention/modified_double_stream_block.py CHANGED
@@ -1,3 +1,4 @@
+from concept_attention.utils import linear_normalization
 import torch
 from torch import nn, Tensor
 import einops
@@ -77,6 +78,7 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concept_vec: Tensor,
         concept_pe: Tensor,
         joint_attention_kwargs=None,
+        normalize_concepts=True,
         **kwargs
     ) -> tuple[Tensor, Tensor]:
         assert concept_vec is not None, "Concept vectors must be provided for this implementation."
@@ -175,19 +177,26 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concept_attn = einops.rearrange(concept_attn, "B H L D -> B L (H D)")
         img_attn = einops.rearrange(img_attn, "B H L D -> B L (H D)")
 
-        # Compute the cross attentions
-        cross_attention_maps = einops.einsum(
-            concept_q,
-            img_q,
-            "batch head concepts dim, batch had patches dim -> batch head concepts patches"
-        )
-        cross_attention_maps = einops.reduce(cross_attention_maps, "batch head concepts patches -> batch concepts patches", reduction="mean")
-        # Compute the concept attentions
-        concept_attention_maps = einops.einsum(
-            concept_attn,
-            img_attn,
-            "batch concepts dim, batch patches dim -> batch concepts patches"
-        )
+        concept_attention_dict = {
+            "output_space_concept_vectors": concept_attn,
+            "output_space_image_vectors": img_attn,
+            "cross_attention_concept_vectors": concept_q,
+            "cross_attention_image_vectors": img_q
+        }
+
+        # # Compute the cross attentions
+        # cross_attention_maps = einops.einsum(
+        #     concept_q,
+        #     img_q,
+        #     "batch head concepts dim, batch had patches dim -> batch head concepts patches"
+        # )
+        # cross_attention_maps = einops.reduce(cross_attention_maps, "batch head concepts patches -> batch concepts patches", reduction="mean")
+        # # Compute the concept attentions
+        # concept_attention_maps = einops.einsum(
+        #     concept_attn,
+        #     img_attn,
+        #     "batch concepts dim, batch patches dim -> batch concepts patches"
+        # )
         # Do the block updates
         # Calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -200,4 +209,4 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concepts = concepts + concept_mod1.gate * self.txt_attn.proj(concept_attn)
         concepts = concepts + concept_mod2.gate * self.txt_mlp((1 + concept_mod2.scale) * self.txt_norm2(concepts) + concept_mod2.shift)
 
-        return img, txt, concepts, cross_attention_maps, concept_attention_maps
+        return img, txt, concepts, concept_attention_dict
concept_attention/modified_flux_dit.py CHANGED
@@ -119,10 +119,14 @@ class ModifiedFluxDiT(nn.Module):
         concept_vec = concept_vec + self.vector_in(original_concept_vec)
         concepts = self.txt_in(concepts)
         ############## Modify the double blocks to also return concept vectors ##############
-        all_cross_attention_maps = []
-        all_concept_attention_maps = []
+        combined_concept_attention_dict = {
+            "output_space_concept_vectors": [],
+            "output_space_image_vectors": [],
+            "cross_attention_concept_vectors": [],
+            "cross_attention_image_vectors": []
+        }
         for block in self.double_blocks:
-            img, txt, concepts, cross_attention_maps, concept_attention_maps = block(
+            img, txt, concepts, concept_attention_dict = block(
                 img=img,
                 txt=txt,
                 vec=vec,
@@ -134,18 +138,18 @@ class ModifiedFluxDiT(nn.Module):
                 iteration=iteration,
                 joint_attention_kwargs=joint_attention_kwargs
             )
-            all_cross_attention_maps.append(cross_attention_maps)
-            all_concept_attention_maps.append(concept_attention_maps)
+            for key in combined_concept_attention_dict.keys():
+                combined_concept_attention_dict[key].append(concept_attention_dict[key])
 
-        all_concept_attention_maps = torch.stack(all_concept_attention_maps, dim=0)
-        all_cross_attention_maps = torch.stack(all_cross_attention_maps, dim=0)
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key], dim=0)
         #####################################################################################
 
         img = torch.cat((txt, img), 1)
 
         # Speed up segmentation by not generating the full image
         if stop_after_multimodal_attentions:
-            return None, all_cross_attention_maps, all_concept_attention_maps
+            return None, combined_concept_attention_dict
 
         # Do the single blocks now
         for block in self.single_blocks:
@@ -154,4 +158,5 @@ class ModifiedFluxDiT(nn.Module):
         img = img[:, txt.shape[1] :, ...]
 
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
-        return img, all_cross_attention_maps, all_concept_attention_maps
+
+        return img, combined_concept_attention_dict
concept_attention/segmentation.py CHANGED
@@ -125,6 +125,7 @@ def encode_image(
     Encodes a PIL image to the VAE latent space and adds noise to it
     """
    if isinstance(image, PIL.Image.Image):
+        image = image.convert("RGB")
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
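
Why the added convert("RGB") matters for uploads: PNGs with an alpha channel would otherwise reach ToTensor as 4-channel images, which the 3-channel VAE encoder cannot accept. A quick standalone check:

from PIL import Image
from torchvision import transforms

rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))    # a typical uploaded PNG with alpha
print(transforms.ToTensor()(rgba).shape)                 # torch.Size([4, 64, 64])
print(transforms.ToTensor()(rgba.convert("RGB")).shape)  # torch.Size([3, 64, 64])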
experiments/test_image_encoder/test_encode_image.py ADDED
@@ -0,0 +1,20 @@
+from concept_attention.concept_attention_pipeline import ConceptAttentionFluxPipeline
+from PIL import Image
+
+if __name__ == "__main__":
+    pipeline = ConceptAttentionFluxPipeline(
+        model_name="flux-schnell",
+        offload_model=True
+    ) # , device="cuda:0") # , offload_model=True)
+
+    image = Image.open("image.png").convert("RGB")
+
+    outputs = pipeline.encode_image(
+        image,
+        concepts=["animal", "background"]
+    )
+    concept_attention_maps = outputs.concept_heatmaps
+
+    concepts = ["animal", "background"]
+    for concept, attention_map in zip(concepts, concept_attention_maps):
+        attention_map.save(f"{concept}_attention_map.png")
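
A possible follow-up to the test script (not in the commit): overlay a saved heatmap on the input image for quick visual inspection. The returned heatmaps are 64x64 RGB PIL images (see compute_heatmaps_from_vectors above), so resize them to the input resolution before blending:

from PIL import Image

image = Image.open("image.png").convert("RGB")
heatmap = Image.open("animal_attention_map.png").convert("RGB")
heatmap = heatmap.resize(image.size, resample=Image.BILINEAR)

overlay = Image.blend(image, heatmap, alpha=0.5)  # 50/50 mix of photo and heatmap
overlay.save("animal_attention_overlay.png")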