Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
ed0bb32
1
Parent(s):
29c7873
Properly render cross attention as well.
Browse files
app.py
CHANGED
@@ -45,10 +45,6 @@ def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timeste
|
|
45 |
if len(concepts) > 9:
|
46 |
raise gr.Error("Please enter at most 9 concepts", duration=10)
|
47 |
|
48 |
-
print(f"Num samples: {num_samples}")
|
49 |
-
print(f"Layer start index: {layer_start_index}")
|
50 |
-
print(f"Noise timestep: {noise_timestep}")
|
51 |
-
print(image)
|
52 |
image = image.convert("RGB")
|
53 |
|
54 |
pipeline_output = pipeline.encode_image(
|
|
|
45 |
if len(concepts) > 9:
|
46 |
raise gr.Error("Please enter at most 9 concepts", duration=10)
|
47 |
|
|
|
|
|
|
|
|
|
48 |
image = image.convert("RGB")
|
49 |
|
50 |
pipeline_output = pipeline.encode_image(
|
concept_attention/concept_attention_pipeline.py
CHANGED
@@ -354,8 +354,7 @@ class ConceptAttentionFluxPipeline():
|
|
354 |
# w=64
|
355 |
# )
|
356 |
concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
|
357 |
-
|
358 |
-
cross_attention_maps = concept_heatmaps
|
359 |
# Convert the torch heatmaps to PIL images.
|
360 |
if return_pil_heatmaps:
|
361 |
concept_heatmaps_min = concept_heatmaps.min()
|
|
|
354 |
# w=64
|
355 |
# )
|
356 |
concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
|
357 |
+
cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
|
|
|
358 |
# Convert the torch heatmaps to PIL images.
|
359 |
if return_pil_heatmaps:
|
360 |
concept_heatmaps_min = concept_heatmaps.min()
|