helblazer811 committed
Commit ed0bb32 · 1 Parent(s): 29c7873

Properly render cross attention as well.

app.py CHANGED
@@ -45,10 +45,6 @@ def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep
     if len(concepts) > 9:
         raise gr.Error("Please enter at most 9 concepts", duration=10)
 
-    print(f"Num samples: {num_samples}")
-    print(f"Layer start index: {layer_start_index}")
-    print(f"Noise timestep: {noise_timestep}")
-    print(image)
     image = image.convert("RGB")
 
     pipeline_output = pipeline.encode_image(
concept_attention/concept_attention_pipeline.py CHANGED
@@ -354,8 +354,7 @@ class ConceptAttentionFluxPipeline():
         # w=64
         # )
         concept_heatmaps = concept_heatmaps.to(torch.float32).detach().cpu().numpy()[0]
-        # cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
-        cross_attention_maps = concept_heatmaps
+        cross_attention_maps = cross_attention_maps.to(torch.float32).detach().cpu().numpy()[0]
         # Convert the torch heatmaps to PIL images.
         if return_pil_heatmaps:
             concept_heatmaps_min = concept_heatmaps.min()
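
For context on the step these changed lines feed into: the surrounding code suggests each heatmap is min-max normalized and then converted to a PIL image when return_pil_heatmaps is set. Below is a minimal sketch of that conversion, assuming per-map normalization and a matplotlib colormap; the helper name heatmap_to_pil and the "plasma" colormap are illustrative assumptions, not taken from the repository.

# Minimal sketch (not the repository's implementation): convert one float32
# heatmap to a PIL image via per-map min-max normalization and a colormap.
# The helper name and the "plasma" colormap are assumptions for illustration.
import numpy as np
import matplotlib  # requires matplotlib >= 3.6 for the colormaps registry
from PIL import Image

def heatmap_to_pil(heatmap: np.ndarray, colormap: str = "plasma") -> Image.Image:
    # Scale the map to [0, 1] so the colormap covers its full dynamic range.
    heatmap_min, heatmap_max = heatmap.min(), heatmap.max()
    normalized = (heatmap - heatmap_min) / (heatmap_max - heatmap_min + 1e-8)
    # Apply the colormap (returns RGBA floats in [0, 1]) and drop the alpha channel.
    rgb = matplotlib.colormaps[colormap](normalized)[..., :3]
    return Image.fromarray((rgb * 255).astype(np.uint8))

With this commit, the same conversion path can be applied to cross_attention_maps as well, since it is now converted to a float32 NumPy array instead of being aliased to concept_heatmaps.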