Spaces: Running on Zero

Commit · 29c7873
Parent(s): bde9560

Added capability for uploading existing images to the UI.

Files changed:
- app.py +10 -2
- concept_attention/concept_attention_pipeline.py +283 -88
- concept_attention/flux/src/flux/sampling.py +12 -8
- concept_attention/image_generator.py +2 -2
- concept_attention/modified_double_stream_block.py +23 -14
- concept_attention/modified_flux_dit.py +14 -9
- concept_attention/segmentation.py +1 -0
- experiments/test_image_encoder/test_encode_image.py +20 -0
app.py
CHANGED
@@ -45,6 +45,12 @@ def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep
     if len(concepts) > 9:
         raise gr.Error("Please enter at most 9 concepts", duration=10)

+    print(f"Num samples: {num_samples}")
+    print(f"Layer start index: {layer_start_index}")
+    print(f"Noise timestep: {noise_timestep}")
+    print(image)
+    image = image.convert("RGB")
+
     pipeline_output = pipeline.encode_image(
         image=image,
         prompt=prompt,
@@ -318,7 +324,7 @@ with gr.Blocks(
        with gr.Column(scale=1, min_width=250):
            generated_image = gr.Image(
                elem_classes="generated-image",
-               show_label=False
+               show_label=False,
            )

        with gr.Column(scale=4):
@@ -419,7 +425,9 @@ with gr.Blocks(
            input_image = gr.Image(
                elem_classes="generated-image",
                show_label=False,
-               interactive=True
+               interactive=True,
+               type="pil",
+               image_mode="RGB",
            )

        with gr.Column(scale=4):
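The key UI change is the input_image component: with type="pil" and image_mode="RGB", Gradio should hand the handler an RGB PIL image, and the new image.convert("RGB") call in the handler acts as an extra guard. A minimal wiring sketch, not the Space's actual layout; the handler body and component names here are illustrative only:

    import gradio as gr

    def encode_image(image, prompt, concepts):
        # `image` arrives as a PIL.Image.Image in RGB mode because of type="pil" / image_mode="RGB".
        # A real handler would forward it to pipeline.encode_image(...) as in the diff above.
        return []

    with gr.Blocks() as demo:
        input_image = gr.Image(show_label=False, interactive=True, type="pil", image_mode="RGB")
        prompt = gr.Textbox(label="Prompt")
        concepts = gr.Textbox(label="Concepts")
        heatmaps = gr.Gallery(label="Concept heatmaps")
        gr.Button("Encode").click(encode_image, inputs=[input_image, prompt, concepts], outputs=heatmaps)

    demo.launch()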
concept_attention/concept_attention_pipeline.py
CHANGED
@@ -5,8 +5,12 @@ from dataclasses import dataclass
 import PIL
 import numpy as np
 import matplotlib.pyplot as plt
+from concept_attention.flux.src.flux.sampling import prepare
+from concept_attention.segmentation import add_noise_to_image, encode_image
+from concept_attention.utils import embed_concepts, linear_normalization
 import torch
 import einops
+from tqdm import tqdm

 from concept_attention.binary_segmentation_baselines.raw_cross_attention import RawCrossAttentionBaseline, RawCrossAttentionSegmentationModel
 from concept_attention.binary_segmentation_baselines.raw_output_space import RawOutputSpaceBaseline, RawOutputSpaceSegmentationModel
@@ -18,6 +22,65 @@ class ConceptAttentionPipelineOutput():
     concept_heatmaps: list[PIL.Image.Image]
     cross_attention_maps: list[PIL.Image.Image]

+
+def compute_heatmaps_from_vectors(
+    image_vectors,
+    concept_vectors,
+    layer_indices: list[int],
+    timesteps: list[int] = list(range(4)),
+    softmax: bool = True,
+    normalize_concepts: bool = True
+):
+    """
+    Accepts image vectors and concept vectors. These can be from cross attentions or attention outputs.
+    """
+    print(f"Image vectors shape: {image_vectors.shape}")
+    print(f"Concept vectors shape: {concept_vectors.shape}")
+    # Check if there are heads in the input
+    if len(image_vectors.shape) == 6:
+        # Collapse the had dimension
+        image_vectors = einops.rearrange(
+            image_vectors,
+            "time layers batch head patches dim -> time layers batch patches (head dim)"
+        )
+        concept_vectors = einops.rearrange(
+            concept_vectors,
+            "time layers batch head concepts dim -> time layers batch concepts (head dim)"
+        )
+
+    # Apply linear normalization to concepts
+    if normalize_concepts:
+        concept_vectors = linear_normalization(concept_vectors, dim=-2)
+
+    # Compute heatmaps
+    heatmaps = einops.einsum(
+        image_vectors,
+        concept_vectors,
+        "time layers batch patches dim, time layers batch concepts dim -> time layers batch concepts patches",
+    )
+
+    # Apply softmax
+    if softmax:
+        heatmaps = torch.nn.functional.softmax(heatmaps, dim=-2)
+    # Pull out the timesteps and layers
+    heatmaps = heatmaps[timesteps]
+    heatmaps = heatmaps[:, layer_indices]
+    # Average over the heatmaps
+    heatmaps = einops.reduce(
+        heatmaps,
+        "time layers batch concepts patches -> batch concepts patches",
+        reduction="mean"
+    )
+    heatmaps = einops.rearrange(
+        heatmaps,
+        "batch concepts (h w) -> batch concepts h w",
+        h=64,
+        w=64
+    )
+
+    return heatmaps
+
 class ConceptAttentionFluxPipeline():
     """
     This is an object that allows you to generate images with flux, and
@@ -66,7 +129,7 @@ class ConceptAttentionFluxPipeline():
         if timesteps is None:
             timesteps = list(range(num_inference_steps))
         # Run the raw output space object
-        image, ...
+        image, concept_attention_dict = self.flux_generator.generate_image(
            width=width,
            height=height,
            prompt=prompt,
@@ -75,56 +138,43 @@ class ConceptAttentionFluxPipeline():
            seed=seed,
            guidance=guidance,
        )
-        ...
-            "time layers batch concepts patches -> batch concepts patches",
-            reduction="mean"
-        )
-        concept_heatmaps = einops.rearrange(
-            concept_heatmaps,
-            "batch concepts (h w) -> batch concepts h w",
-            h=64,
-            w=64
-        )
-        # Cross attention maps
-        if softmax:
-            cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
-
-        cross_attention_maps = cross_attention_maps[:, layer_indices]
-        cross_attention_maps = einops.reduce(
-            cross_attention_maps,
-            "time layers batch concepts patches -> batch concepts patches",
-            reduction="mean"
-        )
-        ...
+
+        cross_attention_maps = compute_heatmaps_from_vectors(
+            concept_attention_dict["cross_attention_image_vectors"],
+            concept_attention_dict["cross_attention_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+        concept_heatmaps = compute_heatmaps_from_vectors(
+            concept_attention_dict["output_space_image_vectors"],
+            concept_attention_dict["output_space_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+
         concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
         cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
         # Convert the torch heatmaps to PIL images.
         if return_pil_heatmaps:
+            concept_heatmaps_min = concept_heatmaps.min()
+            concept_heatmaps_max = concept_heatmaps.max()
             # Convert to a matplotlib color scheme
             colored_heatmaps = []
             for concept_heatmap in concept_heatmaps:
-                concept_heatmap = (concept_heatmap - ...
+                concept_heatmap = (concept_heatmap - concept_heatmaps_min) / (concept_heatmaps_max - concept_heatmaps_min)
                 colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
                 rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
                 colored_heatmaps.append(rgb_image)

             concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]

+            cross_attention_min = cross_attention_maps.min()
+            cross_attention_max = cross_attention_maps.max()
             colored_cross_attention_maps = []
             for cross_attention_map in cross_attention_maps:
-                cross_attention_map = (cross_attention_map - ...
+                cross_attention_map = (cross_attention_map - cross_attention_min) / (cross_attention_max - cross_attention_min)
                 colored_cross_attention_map = plt.get_cmap(cmap)(cross_attention_map)
                 rgb_image = (colored_cross_attention_map[:, :, :3] * 255).astype(np.uint8)
                 colored_cross_attention_maps.append(rgb_image)
@@ -137,58 +187,203 @@ class ConceptAttentionFluxPipeline():
             cross_attention_maps=cross_attention_maps
         )

+    def encode_image(
+        self,
+        image: PIL.Image.Image,
+        concepts: list[str],
+        prompt: str = "", # Optional
+        width: int = 1024,
+        height: int = 1024,
+        layer_indices = list(range(15, 19)),
+        num_samples: int = 1,
+        num_steps: int = 4,
+        noise_timestep: int = 2,
+        device: str = "cuda:0",
+        return_pil_heatmaps: bool = True,
+        seed: int = 0,
+        cmap="plasma",
+        stop_after_multi_modal_attentions=True,
+        softmax=True
+    ) -> ConceptAttentionPipelineOutput:
+        """
+        Encode an image with flux, given a list of concepts.
+        """
+        assert all([layer_index >= 0 and layer_index < 19 for layer_index in layer_indices]), "Invalid layer index"
+        assert height == width, "Height and width must be the same for now"
+        print("Encoding image")
+
+        # Encode the image into the VAE latent space
+        encoded_image_without_noise = encode_image(
+            image,
+            self.flux_generator.ae,
+            offload=self.flux_generator.offload,
+            device=device,
+        )
+        # Do N trials
+        combined_concept_attention_dict = {
+            "cross_attention_image_vectors": [],
+            "cross_attention_concept_vectors": [],
+            "output_space_image_vectors": [],
+            "output_space_concept_vectors": []
+        }
+        print("Sampling")
+        for i in tqdm(range(num_samples)):
+            # Add noise to image
+            encoded_image, timesteps = add_noise_to_image(
+                encoded_image_without_noise,
+                num_steps=num_steps,
+                noise_timestep=noise_timestep,
+                seed=seed + i,
+                width=width,
+                height=height,
+                device=device,
+                is_schnell=self.flux_generator.is_schnell,
+            )
+            # Now run the diffusion model once on the noisy image
+            # Encode the concept vectors
+
+            if self.flux_generator.offload:
+                self.flux_generator.t5, self.flux_generator.clip = self.flux_generator.t5.to(device), self.flux_generator.clip.to(device)
+            inp = prepare(t5=self.flux_generator.t5, clip=self.flux_generator.clip, img=encoded_image, prompt=prompt)
+
+            concept_embeddings, concept_ids, concept_vec = embed_concepts(
+                self.flux_generator.clip,
+                self.flux_generator.t5,
+                concepts,
+            )
+
+            inp["concepts"] = concept_embeddings.to(encoded_image.device)
+            inp["concept_ids"] = concept_ids.to(encoded_image.device)
+            inp["concept_vec"] = concept_vec.to(encoded_image.device)
+            # offload TEs to CPU, load model to gpu
+            if self.flux_generator.offload:
+                self.flux_generator.t5, self.flux_generator.clip = self.flux_generator.t5.cpu(), self.flux_generator.clip.cpu()
+                torch.cuda.empty_cache()
+                self.flux_generator.model = self.flux_generator.model.to(device)
+            # Denoise the intermediate images
+            guidance_vec = torch.full((encoded_image.shape[0],), 0.0, device=encoded_image.device, dtype=encoded_image.dtype)
+            t_curr = timesteps[0]
+            t_prev = timesteps[1]
+            t_vec = torch.full((encoded_image.shape[0],), t_curr, dtype=encoded_image.dtype, device=encoded_image.device)
+            _, concept_attention_dict = self.flux_generator.model(
+                img=inp["img"],
+                img_ids=inp["img_ids"],
+                txt=inp["txt"],
+                txt_ids=inp["txt_ids"],
+                concepts=inp["concepts"],
+                concept_ids=inp["concept_ids"],
+                concept_vec=inp["concept_vec"],
+                y=inp["concept_vec"],
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                stop_after_multimodal_attentions=stop_after_multi_modal_attentions, # Always true for the demo
+                joint_attention_kwargs=None,
+            )
+
+            for key in combined_concept_attention_dict.keys():
+                combined_concept_attention_dict[key].append(concept_attention_dict[key])
+
+            # all_concept_heatmaps.append(concept_heatmaps)
+            # all_cross_attention_maps.append(cross_attention_maps)
+
+        # Pull out the concept and image vectors from each block
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key]).squeeze(1)
+
+        # Compute the heatmaps
+        concept_heatmaps = compute_heatmaps_from_vectors(
+            combined_concept_attention_dict["output_space_image_vectors"],
+            combined_concept_attention_dict["output_space_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+
+        cross_attention_maps = compute_heatmaps_from_vectors(
+            combined_concept_attention_dict["cross_attention_image_vectors"],
+            combined_concept_attention_dict["cross_attention_concept_vectors"],
+            layer_indices=layer_indices,
+            timesteps=timesteps,
+            softmax=softmax
+        )
+
+        # # Pull out the concept and image vectors from each block
+        # image_vectors = torch.stack(self.flux_generator.model.image_vectors).squeeze(1)
+        # concept_vectors = torch.stack(self.flux_generator.model.concept_vectors).squeeze(1)
+        # # Apply linear normalization ot the concept vectors
+        # if True:
+        #     concept_vectors = linear_normalization(concept_vectors, dim=-2)
+        # # Compute the heatmaps
+        # concept_heatmaps = einops.einsum(
+        #     image_vectors,
+        #     concept_vectors,
+        #     "time layers batch patches dim, time layers batch concepts dim -> time layers batch concepts patches",
+        # )
+        # concept_heatmaps = torch.stack(all_concept_heatmaps, dim=0)
+        # cross_attention_maps = torch.stack(all_cross_attention_maps, dim=0)
+        # Concept heamaps extraction
+        # if softmax:
+        #     concept_heatmaps = torch.nn.functional.softmax(concept_heatmaps, dim=-2)
+
+        # concept_heatmaps = concept_heatmaps[:, layer_indices]
+        # concept_heatmaps = einops.reduce(
+        #     concept_heatmaps,
+        #     "time layers batch concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # concept_heatmaps = einops.rearrange(
+        #     concept_heatmaps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
+        # Cross attention maps
+        # if softmax:
+        #     cross_attention_maps = torch.nn.functional.softmax(cross_attention_maps, dim=-2)
+
+        # cross_attention_maps = cross_attention_maps[:, layer_indices]
+        # cross_attention_maps = einops.reduce(
+        #     cross_attention_maps,
+        #     "time layers batch concepts patches -> batch concepts patches",
+        #     reduction="mean"
+        # )
+        # cross_attention_maps = einops.rearrange(
+        #     cross_attention_maps,
+        #     "batch concepts (h w) -> batch concepts h w",
+        #     h=64,
+        #     w=64
+        # )
+        concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
+        # cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
+        cross_attention_maps = concept_heatmaps
+        # Convert the torch heatmaps to PIL images.
+        if return_pil_heatmaps:
+            concept_heatmaps_min = concept_heatmaps.min()
+            concept_heatmaps_max = concept_heatmaps.max()
+            # Convert to a matplotlib color scheme
+            colored_heatmaps = []
+            for concept_heatmap in concept_heatmaps:
+                concept_heatmap = (concept_heatmap - concept_heatmaps_min) / (concept_heatmaps_max - concept_heatmaps_min)
+                colored_heatmap = plt.get_cmap(cmap)(concept_heatmap)
+                rgb_image = (colored_heatmap[:, :, :3] * 255).astype(np.uint8)
+                colored_heatmaps.append(rgb_image)
+
+            concept_heatmaps = [PIL.Image.fromarray(concept_heatmap) for concept_heatmap in colored_heatmaps]
+
+            cross_attention_min = cross_attention_maps.min()
+            cross_attention_max = cross_attention_maps.max()
+            colored_cross_attention_maps = []
+            for cross_attention_map in cross_attention_maps:
+                cross_attention_map = (cross_attention_map - cross_attention_min) / (cross_attention_max - cross_attention_min)
+                colored_cross_attention_map = plt.get_cmap(cmap)(cross_attention_map)
+                rgb_image = (colored_cross_attention_map[:, :, :3] * 255).astype(np.uint8)
+                colored_cross_attention_maps.append(rgb_image)
+
+            cross_attention_maps = [PIL.Image.fromarray(cross_attention_map) for cross_attention_map in colored_cross_attention_maps]
+
+        return ConceptAttentionPipelineOutput(
+            image=image,
+            concept_heatmaps=concept_heatmaps,
+            cross_attention_maps=cross_attention_maps
+        )
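To make the shape bookkeeping in the new compute_heatmaps_from_vectors helper concrete, here is a standalone sketch with dummy tensors. It is not repo code: the sizes are illustrative, and the optional linear_normalization step over the concept axis is skipped.

    import torch
    import einops

    # Dummy sizes: 4 denoising steps, 19 double-stream blocks, 1 image,
    # a 64x64 patch grid, 2 concepts, and an illustrative feature dim of 16.
    time, layers, batch, patches, concepts, dim = 4, 19, 1, 64 * 64, 2, 16
    image_vectors = torch.randn(time, layers, batch, patches, dim)
    concept_vectors = torch.randn(time, layers, batch, concepts, dim)

    # Dot product of every concept vector with every image-patch vector
    heatmaps = einops.einsum(
        image_vectors, concept_vectors,
        "time layers batch patches dim, time layers batch concepts dim -> time layers batch concepts patches",
    )
    # Softmax over the concept axis (dim=-2), as in the diff
    heatmaps = torch.nn.functional.softmax(heatmaps, dim=-2)
    # Keep the requested timesteps and layers, then average them away
    heatmaps = heatmaps[list(range(4))][:, list(range(15, 19))]
    heatmaps = einops.reduce(heatmaps, "time layers batch concepts patches -> batch concepts patches", "mean")
    # Unflatten the patch axis into a 64x64 spatial grid
    heatmaps = einops.rearrange(heatmaps, "batch concepts (h w) -> batch concepts h w", h=64, w=64)
    print(heatmaps.shape)  # torch.Size([1, 2, 64, 64])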
concept_attention/flux/src/flux/sampling.py
CHANGED
@@ -111,14 +111,18 @@ def denoise(
     joint_attention_kwargs=None,
 ):
     intermediate_images = [img]
-    ...
+    combined_concept_attention_dict = {
+        "output_space_concept_vectors": [],
+        "output_space_image_vectors": [],
+        "cross_attention_concept_vectors": [],
+        "cross_attention_image_vectors": [],
+    }
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     iteration = 0
     for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        pred, ...
+        pred, concept_attention_dict = model(
             img=img,
             img_ids=img_ids,
             txt=txt,
@@ -138,13 +142,13 @@ def denoise(
         # increment iteration
         iteration += 1

-        ...
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key].append(concept_attention_dict[key])

-        ...
+    for key in combined_concept_attention_dict.keys():
+        combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key], dim=0)

-    return img, intermediate_images, ...
+    return img, intermediate_images, combined_concept_attention_dict

 def unpack(x: Tensor, height: int, width: int) -> Tensor:
     return rearrange(
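The denoise loop uses an accumulate-then-stack pattern: each step's concept_attention_dict (already stacked over the 19 double-stream blocks inside the model, see modified_flux_dit.py below) is appended, and torch.stack(..., dim=0) adds the leading time axis that compute_heatmaps_from_vectors expects. A small sketch with dummy tensors (shapes illustrative, not the model's true sizes):

    import torch

    num_steps, num_layers = 4, 19
    combined = {"output_space_image_vectors": []}           # one list per key in practice
    for _ in range(num_steps - 1):                           # one entry per (t_curr, t_prev) pair
        per_step = torch.randn(num_layers, 1, 16, 8)         # (layers, batch, patches, dim), illustrative
        combined["output_space_image_vectors"].append(per_step)
    for key in combined.keys():
        combined[key] = torch.stack(combined[key], dim=0)    # adds the leading "time" axis
    print(combined["output_space_image_vectors"].shape)      # torch.Size([3, 19, 1, 16, 8])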
concept_attention/image_generator.py
CHANGED
@@ -171,7 +171,7 @@ class FluxGenerator():
             torch.cuda.empty_cache()
             self.model = self.model.to(self.device)
         # denoise initial noise
-        x, ...
+        x, _, concept_attention_dict = denoise(
             self.model,
             **inp,
             timesteps=timesteps,
@@ -203,4 +203,4 @@ class FluxGenerator():

         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

-        return img, ...
+        return img, concept_attention_dict
concept_attention/modified_double_stream_block.py
CHANGED
@@ -1,3 +1,4 @@
+from concept_attention.utils import linear_normalization
 import torch
 from torch import nn, Tensor
 import einops
@@ -77,6 +78,7 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concept_vec: Tensor,
         concept_pe: Tensor,
         joint_attention_kwargs=None,
+        normalize_concepts=True,
         **kwargs
     ) -> tuple[Tensor, Tensor]:
         assert concept_vec is not None, "Concept vectors must be provided for this implementation."
@@ -175,19 +177,26 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concept_attn = einops.rearrange(concept_attn, "B H L D -> B L (H D)")
         img_attn = einops.rearrange(img_attn, "B H L D -> B L (H D)")

-        ...
+        concept_attention_dict = {
+            "output_space_concept_vectors": concept_attn,
+            "output_space_image_vectors": img_attn,
+            "cross_attention_concept_vectors": concept_q,
+            "cross_attention_image_vectors": img_q
+        }
+
+        # # Compute the cross attentions
+        # cross_attention_maps = einops.einsum(
+        #     concept_q,
+        #     img_q,
+        #     "batch head concepts dim, batch had patches dim -> batch head concepts patches"
+        # )
+        # cross_attention_maps = einops.reduce(cross_attention_maps, "batch head concepts patches -> batch concepts patches", reduction="mean")
+        # # Compute the concept attentions
+        # concept_attention_maps = einops.einsum(
+        #     concept_attn,
+        #     img_attn,
+        #     "batch concepts dim, batch patches dim -> batch concepts patches"
+        # )
         # Do the block updates
         # Calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -200,4 +209,4 @@ class ModifiedDoubleStreamBlock(nn.Module):
         concepts = concepts + concept_mod1.gate * self.txt_attn.proj(concept_attn)
         concepts = concepts + concept_mod2.gate * self.txt_mlp((1 + concept_mod2.scale) * self.txt_norm2(concepts) + concept_mod2.shift)

-        return img, txt, concepts, ...
+        return img, txt, concepts, concept_attention_dict
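The dictionary returned by the block stores two families of vectors: the "output space" entries are the per-token attention outputs with heads already flattened to (B, L, H*D), while the "cross attention" entries are the raw query projections that keep their head axis. The commented-out einsums in the diff indicate how either pair becomes a concepts x patches score map; a dummy-tensor sketch of both (not repo code, sizes illustrative):

    import torch
    import einops

    B, H, D = 1, 24, 128            # batch, attention heads, per-head dim (illustrative)
    patches, concepts = 256, 2

    # "Output space" vectors: attention outputs with heads collapsed via "B H L D -> B L (H D)"
    img_attn = torch.randn(B, patches, H * D)
    concept_attn = torch.randn(B, concepts, H * D)
    concept_maps = einops.einsum(
        concept_attn, img_attn,
        "batch concepts dim, batch patches dim -> batch concepts patches",
    )

    # "Cross attention" vectors: raw query projections that keep the head axis
    img_q = torch.randn(B, H, patches, D)
    concept_q = torch.randn(B, H, concepts, D)
    cross_maps = einops.einsum(
        concept_q, img_q,
        "batch head concepts dim, batch head patches dim -> batch head concepts patches",
    )
    cross_maps = einops.reduce(cross_maps, "batch head concepts patches -> batch concepts patches", "mean")

    print(concept_maps.shape, cross_maps.shape)  # torch.Size([1, 2, 256]) twice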
concept_attention/modified_flux_dit.py
CHANGED
@@ -119,10 +119,14 @@ class ModifiedFluxDiT(nn.Module):
         concept_vec = concept_vec + self.vector_in(original_concept_vec)
         concepts = self.txt_in(concepts)
         ############## Modify the double blocks to also return concept vectors ##############
-        ...
+        combined_concept_attention_dict = {
+            "output_space_concept_vectors": [],
+            "output_space_image_vectors": [],
+            "cross_attention_concept_vectors": [],
+            "cross_attention_image_vectors": []
+        }
         for block in self.double_blocks:
-            img, txt, concepts, ...
+            img, txt, concepts, concept_attention_dict = block(
                 img=img,
                 txt=txt,
                 vec=vec,
@@ -134,18 +138,18 @@ class ModifiedFluxDiT(nn.Module):
                 iteration=iteration,
                 joint_attention_kwargs=joint_attention_kwargs
             )
-            ...
+            for key in combined_concept_attention_dict.keys():
+                combined_concept_attention_dict[key].append(concept_attention_dict[key])

-        ...
+        for key in combined_concept_attention_dict.keys():
+            combined_concept_attention_dict[key] = torch.stack(combined_concept_attention_dict[key], dim=0)
         #####################################################################################

         img = torch.cat((txt, img), 1)

         # Speed up segmentation by not generating the full image
         if stop_after_multimodal_attentions:
-            return None, ...
+            return None, combined_concept_attention_dict

         # Do the single blocks now
         for block in self.single_blocks:
@@ -154,4 +158,5 @@ class ModifiedFluxDiT(nn.Module):
         img = img[:, txt.shape[1] :, ...]

         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
-        ...
+
+        return img, combined_concept_attention_dict
concept_attention/segmentation.py
CHANGED
@@ -125,6 +125,7 @@ def encode_image(
     Encodes a PIL image to the VAE latent space and adds noise to it
     """
     if isinstance(image, PIL.Image.Image):
+        image = image.convert("RGB")
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Lambda(lambda x: 2.0 * x - 1.0),
experiments/test_image_encoder/test_encode_image.py
ADDED
@@ -0,0 +1,20 @@
+from concept_attention.concept_attention_pipeline import ConceptAttentionFluxPipeline
+from PIL import Image
+
+if __name__ == "__main__":
+    pipeline = ConceptAttentionFluxPipeline(
+        model_name="flux-schnell",
+        offload_model=True
+    ) # , device="cuda:0") # , offload_model=True)
+
+    image = Image.open("image.png").convert("RGB")
+
+    outputs = pipeline.encode_image(
+        image,
+        concepts=["animal", "background"]
+    )
+    concept_attention_maps = outputs.concept_heatmaps
+
+    concepts = ["animal", "background"]
+    for concept, attention_map in zip(concepts, concept_attention_maps):
+        attention_map.save(f"{concept}_attention_map.png")