helblazer811 committed
Commit 227c367 · 1 Parent(s): ed0bb32

Fixed UI for mobile and the logic/UI for the second page.

CrossAttentionCallout.svg ADDED
app.py CHANGED
```diff
@@ -69,14 +69,27 @@ def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timeste
 
         cross_attention_heatmaps = pipeline_output.cross_attention_maps
         cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
-        cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+        cross_attention_maps_and_labels = []
+        prompt_tokens = prompt.split()
+        for concept_index in range(len(concepts)):
+            concept = concepts[concept_index]
+            if concept in prompt_tokens:
+                cross_attention_maps_and_labels.append(
+                    (cross_attention_heatmaps[concept_index], concept)
+                )
+            else:
+                # Exclude this concept because it is only generated due to ConceptAttention's causal attention mechanism
+                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
+                cross_attention_maps_and_labels.append(
+                    (empty_image, concept)
+                )
 
         return output_image, \
             gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
             gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
 
     except gr.Error as e:
-        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
+        return None, gr.update(value=[], columns=1)  # , gr.update(value=[], columns=1)
 
 
 @spaces.GPU(duration=60)
@@ -116,7 +129,20 @@ def generate_image(prompt, concepts, seed, layer_start_index, timestep_start_ind
 
         cross_attention_heatmaps = pipeline_output.cross_attention_maps
         cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
-        cross_attention_maps_and_labels = [(cross_attention_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
+        cross_attention_maps_and_labels = []
+        prompt_tokens = prompt.split()
+        for concept_index in range(len(concepts)):
+            concept = concepts[concept_index]
+            if concept in prompt_tokens:
+                cross_attention_maps_and_labels.append(
+                    (cross_attention_heatmaps[concept_index], concept)
+                )
+            else:
+                # Exclude this concept because it is only generated due to ConceptAttention's causal attention mechanism
+                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
+                cross_attention_maps_and_labels.append(
+                    (empty_image, concept)
+                )
 
         return output_image, \
             gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
@@ -145,11 +171,7 @@ with gr.Blocks(
         .input {
             height: 47px;
         }
-        .input-column {
-            flex-direction: column;
-            gap: 0px;
-            height: 100%;
-        }
+
         .input-column-label {}
         .gallery {
             height: 220px;
@@ -162,52 +184,49 @@ with gr.Blocks(
            scrollbar-width: thin;
            scrollbar-color: grey black;
        }
-
-        /* Show only on screens wider than 768px (adjust as needed
-        @media (min-width: 1024px) {
-            .svg-container {
-                min-width: 150px;
-                width: 200px;
-                padding-top: 540px;
-            }
-        }
 
        @media (min-width: 1280px) {
-            .svg-container {
-                min-width: 200px;
-                width: 300px;
-                padding-top: 420px;
-            }
-        }
-        @media (min-width: 1530px) {
-            .svg-container {
-                min-width: 200px;
-                width: 300px;
-                padding-top: 400px;
-            }
-        }
-
-        */
-
-        @media (min-width: 1024px) {
            .svg-container {
                min-width: 250px;
+                display: flex;
+                flex-direction: column;
+                padding-top: 340px;
            }
-            #concept-attention-callout-svg {
+            .callout {
                width: 250px;
            }
+            .input-row {
+                height: 100px;
+            }
+            .input-column {
+                flex-direction: column;
+                gap: 0px;
+                height: 100%;
+            }
        }
 
-
-        @media (max-width: 1024px) {
+        @media (max-width: 1280px) {
            .svg-container {
                display: none !important;
            }
-            #concept-attention-callout-svg {
+            .callout {
                display: none;
            }
        }
 
+        /*
+        @media (max-width: 1024px) {
+            .svg-container {
+                display: none !important;
+                display: flex;
+                flex-direction: column;
+            }
+            .callout {
+                display: none;
+            }
+        }
+        */
+
        .header {
            display: flex;
            flex-direction: column;
@@ -241,11 +260,6 @@ with gr.Blocks(
            text-decoration: none;
        }
 
-        .svg-container {
-            display: flex;
-            justify-content: center;
-            align-items: center;
-        }
 
        .caption-label {
            font-size: 1.15em;
@@ -415,8 +429,7 @@ with gr.Blocks(
                    elem_classes="input"
                )
 
-            with gr.Row(elem_classes="gallery-container", scale=8):
-
+            with gr.Row(elem_classes="gallery-container", scale=8, equal_height=True):
                with gr.Column(scale=1, min_width=250):
                    input_image = gr.Image(
                        elem_classes="generated-image",
@@ -424,9 +437,10 @@ with gr.Blocks(
                        interactive=True,
                        type="pil",
                        image_mode="RGB",
+                        scale=1
                    )
 
-                with gr.Column(scale=4):
+                with gr.Column(scale=2):
                    concept_attention_gallery = gr.Gallery(
                        label="Concept Attention (Ours)",
                        show_label=True,
@@ -438,7 +452,6 @@ with gr.Blocks(
                        elem_id="concept-attention-gallery",
                        # scale=4
                    )
-
                    cross_attention_gallery = gr.Gallery(
                        label="Cross Attention",
                        show_label=True,
@@ -476,7 +489,11 @@ with gr.Blocks(
 
            with gr.Row(scale=4, elem_classes="svg-container"):
                concept_attention_callout_svg = gr.HTML(
-                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
+                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' class='callout'/>",
+                    # container=False,
+                )
+                cross_attention_callout_svg = gr.HTML(
+                    "<img src='/gradio_api/file=CrossAttentionCallout.svg' class='callout'/>",
                    # container=False,
                )
 
```
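The substantive logic change in app.py is the filtering shown above: a cross-attention heatmap is kept only for a concept that appears among the whitespace-split prompt tokens; every other concept gets a dark placeholder tile, since (per the in-code comment) its map only exists because of ConceptAttention's attention mechanism. A minimal standalone sketch of that behavior, where the helper name and the standalone `IMG_SIZE` constant are illustrative and not part of the repo:

```python
from PIL import Image

IMG_SIZE = 1024  # stand-in; app.py defines its own IMG_SIZE

def pair_heatmaps_with_concepts(heatmaps, concepts, prompt, img_size=IMG_SIZE):
    """Hypothetical helper mirroring the loop above: keep a concept's cross-attention
    heatmap only when the concept is a whitespace-delimited token of the prompt,
    otherwise substitute a dark placeholder tile."""
    prompt_tokens = prompt.split()
    placeholder = Image.new("RGB", (img_size, img_size), (39, 39, 42))  # same dark fill as the commit
    return [
        (heatmap if concept in prompt_tokens else placeholder, concept)
        for heatmap, concept in zip(heatmaps, concepts)
    ]
```

Two things to note about the committed code, not the sketch: the token match is exact and case-sensitive, so "dog" will not match "dog." or "Dog" in the prompt; and the success path still returns three values while the `except` branch now returns only two, which only works if the event handler is wired accordingly.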
 
concept_attention/concept_attention_pipeline.py CHANGED
```diff
@@ -29,13 +29,11 @@ def compute_heatmaps_from_vectors(
     layer_indices: list[int],
     timesteps: list[int] = list(range(4)),
     softmax: bool = True,
-    normalize_concepts: bool = True
+    normalize_concepts: bool = False
 ):
     """
     Accepts image vectors and concept vectors. These can be from cross attentions or attention outputs.
     """
-    print(f"Image vectors shape: {image_vectors.shape}")
-    print(f"Concept vectors shape: {concept_vectors.shape}")
     # Check if there are heads in the input
     if len(image_vectors.shape) == 6:
         # Collapse the had dimension
@@ -139,6 +137,25 @@ class ConceptAttentionFluxPipeline():
             guidance=guidance,
         )
 
+        # # cross_attention_maps = concept_attention_dict["cross_attention_maps"]
+        # # Apply softmax
+        # if softmax:
+        #     cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
+        # # Pull out the timesteps and layers
+        # cross_attention_maps = cross_attention_maps[timesteps]
+        # cross_attention_maps = cross_attention_maps[:, layer_indices]
+        # # Average over time, had, and layers
+        # cross_attention_maps = einops.reduce(
+        #     cross_attention_maps,
+        #     "time layers batch head concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # cross_attention_maps = einops.rearrange(
+        #     cross_attention_maps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
         cross_attention_maps = compute_heatmaps_from_vectors(
             concept_attention_dict["cross_attention_image_vectors"],
             concept_attention_dict["cross_attention_concept_vectors"],
@@ -146,6 +163,7 @@ class ConceptAttentionFluxPipeline():
             timesteps=timesteps,
             softmax=softmax
         )
+        # Compute concept the heatmaps
         concept_heatmaps = compute_heatmaps_from_vectors(
             concept_attention_dict["output_space_image_vectors"],
             concept_attention_dict["output_space_concept_vectors"],
@@ -223,8 +241,9 @@ class ConceptAttentionFluxPipeline():
         combined_concept_attention_dict = {
             "cross_attention_image_vectors": [],
             "cross_attention_concept_vectors": [],
+            # "cross_attention_maps": [],
             "output_space_image_vectors": [],
-            "output_space_concept_vectors": []
+            "output_space_concept_vectors": [],
         }
         print("Sampling")
         for i in tqdm(range(num_samples)):
@@ -307,6 +326,26 @@ class ConceptAttentionFluxPipeline():
             softmax=softmax
         )
 
+        # cross_attention_maps = concept_attention_dict["cross_attention_maps"]
+        # # Apply softmax
+        # if softmax:
+        #     cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
+        # # Pull out the timesteps and layers
+        # cross_attention_maps = cross_attention_maps[timesteps]
+        # cross_attention_maps = cross_attention_maps[:, layer_indices]
+        # # Average over time, had, and layers
+        # cross_attention_maps = einops.reduce(
+        #     cross_attention_maps,
+        #     "time layers batch head concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # cross_attention_maps = einops.rearrange(
+        #     cross_attention_maps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
+
         # # Pull out the concept and image vectors from each block
         # image_vectors = torch.stack(self.flux_generator.model.image_vectors).squeeze(1)
         # concept_vectors = torch.stack(self.flux_generator.model.concept_vectors).squeeze(1)
```
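Both pipeline paths now derive cross-attention maps from the stored query vectors via `compute_heatmaps_from_vectors` rather than from a precollected `cross_attention_maps` tensor (the retained comment blocks document the older route), and the `normalize_concepts` default flips from `True` to `False`. For orientation only, here is a rough sketch of the kind of reduction the commented-out block describes; it is not the repository's implementation, and the `(time, layers, batch, tokens, dim)` layout is an assumption read off the einops patterns above:

```python
import torch
import einops

def sketch_heatmaps_from_vectors(image_vectors, concept_vectors, layer_indices,
                                 timesteps=list(range(4)), softmax=True):
    # Dot product of every concept vector with every image-patch vector
    maps = einops.einsum(
        concept_vectors,
        image_vectors,
        "time layers batch concepts dim, time layers batch patches dim -> time layers batch concepts patches",
    )
    if softmax:
        # Normalize over the concept axis (dim=-2), as in the commented-out block
        maps = torch.nn.functional.softmax(maps, dim=-2)
    # Keep only the requested timesteps and layers, then average them away
    maps = maps[timesteps][:, layer_indices]
    maps = einops.reduce(maps, "time layers batch concepts patches -> batch concepts patches", "mean")
    # Fold the flattened patch axis back into a spatial grid (h=64, w=64, matching the comments above)
    return einops.rearrange(maps, "batch concepts (h w) -> batch concepts h w", h=64, w=64)
```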
concept_attention/flux/src/flux/sampling.py CHANGED
```diff
@@ -114,6 +114,7 @@ def denoise(
     combined_concept_attention_dict = {
         "output_space_concept_vectors": [],
         "output_space_image_vectors": [],
+        # "cross_attention_maps": [],
         "cross_attention_concept_vectors": [],
         "cross_attention_image_vectors": [],
     }
```
concept_attention/modified_double_stream_block.py CHANGED
```diff
@@ -4,7 +4,6 @@ from torch import nn, Tensor
 import einops
 import math
 import torch.nn.functional as F
-import matplotlib.pyplot as plt
 
 from concept_attention.flux.src.flux.modules.layers import Modulation, SelfAttention
 from concept_attention.flux.src.flux.math import apply_rope
@@ -167,7 +166,6 @@ class ModifiedDoubleStreamBlock(nn.Module):
         )
         # Separate the concept and image attentions
         concept_attn = concept_image_attn[:, :, :concepts.shape[1]]
-
         # Rearrange the attention tensors
         txt_attn = einops.rearrange(txt_attn, "B H L D -> B L (H D)")
         if joint_attention_kwargs is not None and joint_attention_kwargs.get("keep_head_dim", False):
@@ -177,26 +175,20 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concept_attn = einops.rearrange(concept_attn, "B H L D -> B L (H D)")
         img_attn = einops.rearrange(img_attn, "B H L D -> B L (H D)")
 
-        concept_attention_dict = {
-            "output_space_concept_vectors": concept_attn,
-            "output_space_image_vectors": img_attn,
-            "cross_attention_concept_vectors": concept_q,
-            "cross_attention_image_vectors": img_q
-        }
-
         # # Compute the cross attentions
         # cross_attention_maps = einops.einsum(
         #     concept_q,
         #     img_q,
         #     "batch head concepts dim, batch had patches dim -> batch head concepts patches"
         # )
-        # cross_attention_maps = einops.reduce(cross_attention_maps, "batch head concepts patches -> batch concepts patches", reduction="mean")
-        # # Compute the concept attentions
-        # concept_attention_maps = einops.einsum(
-        #     concept_attn,
-        #     img_attn,
-        #     "batch concepts dim, batch patches dim -> batch concepts patches"
-        # )
+        # Collect all of the concept attention information
+        concept_attention_dict = {
+            "output_space_concept_vectors": concept_attn.detach(),
+            "output_space_image_vectors": img_attn.detach(),
+            # "cross_attention_maps": cross_attention_maps.detach(),
+            "cross_attention_concept_vectors": concept_q.detach(),
+            "cross_attention_image_vectors": img_q.detach()
+        }
         # Do the block updates
         # Calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
```
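The behavioral change in this block, beyond dropping the unused matplotlib import and moving the dict construction, is that the stashed tensors are now `.detach()`ed before being collected into `concept_attention_dict`. Detaching ensures the stored copies carry no autograd graph, so accumulating them across every double-stream block and timestep cannot pin intermediate activations in memory. A minimal, generic illustration of the pattern (not code from the repo):

```python
import torch

x = torch.randn(2, 4, requires_grad=True)
y = (x * 2).sum(dim=-1)            # y is attached to the autograd graph

stash = {"vectors": y.detach()}    # stored copy shares data but drops the graph reference
assert stash["vectors"].requires_grad is False
```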
concept_attention/modified_flux_dit.py CHANGED
```diff
@@ -122,8 +122,9 @@ class ModifiedFluxDiT(nn.Module):
         combined_concept_attention_dict = {
             "output_space_concept_vectors": [],
             "output_space_image_vectors": [],
+            # "cross_attention_maps": [],
             "cross_attention_concept_vectors": [],
-            "cross_attention_image_vectors": []
+            "cross_attention_image_vectors": [],
         }
         for block in self.double_blocks:
             img, txt, concepts, concept_attention_dict = block(
```