helblazer811 committed
Commit be12820 · 1 Parent(s): a8468a7

Changes to app UI.

Files changed (2)
  1. ConceptAttentionCallout.svg +3 -3
  2. app.py +194 -123
ConceptAttentionCallout.svg CHANGED
app.py CHANGED
@@ -15,13 +15,13 @@ COLUMNS = 5
 def update_default_concepts(prompt):
     default_concepts = {
         "A dog by a tree": ["dog", "grass", "tree", "background"],
-        "A dragon": ["dragon", "sky", "rock", "cloud"],
+        "A man on the beach": ["man", "dirt", "ocean", "sky"],
         "A hot air balloon": ["balloon", "sky", "water", "tree"]
     }

     return gr.update(value=default_concepts.get(prompt, []))

-pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda") # , offload_model=True)
+pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda:2", offload_model=True)

 def convert_pil_to_bytes(img):
     img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
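This hunk pins the pipeline to a specific GPU and turns on model offloading. A minimal sketch of how the same constructor could be made portable across machines (the concept_attention import path and the device-picking logic are assumptions for illustration, not part of this commit):

import torch
from concept_attention import ConceptAttentionFluxPipeline  # import path assumed

# The commit hard-codes device="cuda:2" and offload_model=True for the Space's
# hardware; picking the device at runtime keeps the script usable elsewhere.
if torch.cuda.device_count() > 2:
    device = "cuda:2"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pipeline = ConceptAttentionFluxPipeline(
    model_name="flux-schnell",  # same checkpoint name used in app.py
    device=device,
    offload_model=True,         # offload the model to save GPU memory, as enabled in this commit
)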
@@ -76,50 +76,42 @@ def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_ind
 with gr.Blocks(
     css="""
         .container {
-            max-width: 1400px;
+            max-width: 1300px;
             margin: 0 auto;
             padding: 20px;
         }
-        .authors { text-align: center; margin-bottom: 10px; }
-        .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
-        .abstract { text-align: center; margin-bottom: 40px; }
+        .application {
+            max-width: 1200px;
+        }
         .generated-image {
             display: flex;
             align-items: center;
             justify-content: center;
             height: 100%; /* Ensures full height */
         }
-        .header {
-            display: flex;
-            flex-direction: column;
-        }
+
         .input {
             height: 47px;
         }
         .input-column {
             flex-direction: column;
             gap: 0px;
+            height: 100%;
         }
         .input-column-label {}
-        .gallery {}
+        .gallery {
+            height: 200px;
+        }
         .run-button-column {
             width: 100px !important;
         }
-        #title {
-            font-size: 2.4em;
-            text-align: center;
-            margin-bottom: 10px;
-        }
-        #subtitle {
-            font-size: 2.0em;
-            text-align: center;
-        }

-        #concept-attention-callout-svg {
-            width: 250px;
+        .gallery-container {
+            scrollbar-width: thin;
+            scrollbar-color: grey black;
         }
-
-        /* Show only on screens wider than 768px (adjust as needed) */
+
+        /* Show only on screens wider than 768px (adjust as needed)
         @media (min-width: 1024px) {
             .svg-container {
                 min-width: 150px;
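These CSS rules only apply because the layout code below attaches matching hooks to each component: gr.Blocks(css=...) injects the stylesheet, and elem_classes / elem_id on a component add the class or id the selectors target. A self-contained sketch of that pattern (the selectors and components here are illustrative, not the app's full layout):

import gradio as gr

# Illustrative stylesheet; .input and #concept-attention-gallery mirror hooks used in app.py.
CSS = """
.input { height: 47px; }
#concept-attention-gallery { border: 1px solid grey; }
"""

with gr.Blocks(css=CSS) as demo:
    prompt = gr.Textbox(label="Prompt", elem_classes="input")
    gallery = gr.Gallery(label="Concept Attention (Ours)", elem_id="concept-attention-gallery")

if __name__ == "__main__":
    demo.launch()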
@@ -137,124 +129,203 @@ with gr.Blocks(
         }
         @media (min-width: 1530px) {
             .svg-container {
-                min-width: 200px;
+                min-width: 200px;
                 width: 300px;
                 padding-top: 400px;
             }
         }

+        */
+
+        @media (min-width: 1024px) {
+            .svg-container {
+                min-width: 250px;
+            }
+            #concept-attention-callout-svg {
+                width: 250px;
+            }
+        }
+

         @media (max-width: 1024px) {
             .svg-container {
+                display: none !important;
+            }
+            #concept-attention-callout-svg {
                 display: none;
             }
         }

-    """
-    # ,
-    # elem_classes="container"
-) as demo:
-    with gr.Row(elem_classes="container"):
-        with gr.Column(elem_classes="application", scale=15):
-            with gr.Row(scale=3, elem_classes="header"):
-                gr.HTML("<h1 id='title'> ConceptAttention: Visualize Any Concepts in Your Generated Images</h1>")
-                gr.HTML("<h2 id='subtitle'> Interpret generative models with precise, high-quality heatmaps. <br/> Check out our paper <a href='https://arxiv.org/abs/2502.04320'> here </a>. </h2>")
-
-            with gr.Row(scale=1, equal_height=True):
-                with gr.Column(scale=4, elem_classes="input-column", min_width=250):
-                    gr.HTML(
-                        "Write a Prompt",
-                        elem_classes="input-column-label"
-                    )
-                    prompt = gr.Dropdown(
-                        ["A dog by a tree", "A dragon", "A hot air balloon"],
-                        container=False,
-                        allow_custom_value=True,
-                        elem_classes="input"
-                    )
+        .header {
+            display: flex;
+            flex-direction: column;
+        }
+        #title {
+            font-size: 4.4em;
+            color: #F3B13E;
+            text-align: center;
+            margin: 5px;
+        }
+        #subtitle {
+            font-size: 3.0em;
+            color: #FAE2BA;
+            text-align: center;
+            margin: 5px;
+        }
+        #abstract {
+            text-align: center;
+            font-size: 2.0em;
+            color:rgb(219, 219, 219);
+            margin: 5px;
+            margin-top: 10px;
+        }
+        #links {
+            text-align: center;
+            font-size: 2.0em;
+            margin: 5px;
+        }
+        #links a {
+            color: #93B7E9;
+            text-decoration: none;
+        }

-                with gr.Column(scale=7, elem_classes="input-column"):
-                    gr.HTML(
-                        "Select or Write Concepts",
-                        elem_classes="input-column-label"
-                    )
-                    concepts = gr.Dropdown(
-                        ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
-                        value=["dog", "grass", "tree", "background"],
-                        multiselect=True,
-                        label="Concepts",
-                        container=False,
-                        allow_custom_value=True,
-                        # scale=4,
-                        elem_classes="input",
-                        max_choices=5
-                    )
+        .svg-container {
+            display: flex;
+            justify-content: center;
+            align-items: center;
+        }

-                with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
-                    gr.HTML(
-                        "&#8203;",
-                        elem_classes="input-column-label"
-                    )
-                    submit_btn = gr.Button(
-                        "Run",
-                        elem_classes="input"
-                    )
+        .caption-label {
+            font-size: 1.15em;
+        }

-            with gr.Row(elem_classes="gallery", scale=8):
+        .gallery label {
+            font-size: 1.15em;
+        }

-                with gr.Column(scale=1, min_width=250):
-                    generated_image = gr.Image(
-                        elem_classes="generated-image",
-                        show_label=False
-                    )
-
-                with gr.Column(scale=4):
-                    concept_attention_gallery = gr.Gallery(
-                        label="Concept Attention (Ours)",
-                        show_label=True,
-                        # columns=3,
-                        rows=1,
-                        object_fit="contain",
-                        height="200px",
-                        elem_classes="gallery",
-                        elem_id="concept-attention-gallery"
-                    )
+    """
+) as demo:

-                    cross_attention_gallery = gr.Gallery(
-                        label="Cross Attention",
-                        show_label=True,
-                        # columns=3,
-                        rows=1,
-                        object_fit="contain",
-                        height="200px",
-                        elem_classes="gallery"
+    # with gr.Column(elem_classes="container"):
+
+
+    with gr.Row(elem_classes="container", scale=8):
+
+        with gr.Column(elem_classes="application-content", scale=10):
+
+            with gr.Row(scale=3, elem_classes="header"):
+                gr.HTML("""
+                    <h1 id='title'> ConceptAttention </h1>
+                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
+                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
+                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
+                """)
+
+            with gr.Row(elem_classes="input-row", scale=2):
+                with gr.Column(scale=4, elem_classes="input-column", min_width=250):
+                    gr.HTML(
+                        "Write a Prompt",
+                        elem_classes="input-column-label"
+                    )
+                    prompt = gr.Dropdown(
+                        ["A dog by a tree", "A dragon", "A hot air balloon"],
+                        container=False,
+                        allow_custom_value=True,
+                        elem_classes="input"
+                    )
+
+                with gr.Column(scale=7, elem_classes="input-column"):
+                    gr.HTML(
+                        "Select or Write Concepts",
+                        elem_classes="input-column-label"
+                    )
+                    concepts = gr.Dropdown(
+                        ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
+                        value=["dog", "grass", "tree", "background"],
+                        multiselect=True,
+                        label="Concepts",
+                        container=False,
+                        allow_custom_value=True,
+                        # scale=4,
+                        elem_classes="input",
+                        max_choices=5
+                    )
+
+                with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
+                    gr.HTML(
+                        "&#8203;",
+                        elem_classes="input-column-label"
+                    )
+                    submit_btn = gr.Button(
+                        "Run",
+                        elem_classes="input"
+                    )
+
+            with gr.Row(elem_classes="gallery-container", scale=8):
+
+                with gr.Column(scale=1, min_width=250):
+                    generated_image = gr.Image(
+                        elem_classes="generated-image",
+                        show_label=False
+                    )
+
+                with gr.Column(scale=4):
+                    concept_attention_gallery = gr.Gallery(
+                        label="Concept Attention (Ours)",
+                        show_label=True,
+                        # columns=3,
+                        rows=1,
+                        object_fit="contain",
+                        # height="200px",
+                        elem_classes="gallery",
+                        elem_id="concept-attention-gallery",
+                        # scale=4
+                    )
+
+                    cross_attention_gallery = gr.Gallery(
+                        label="Cross Attention",
+                        show_label=True,
+                        # columns=3,
+                        rows=1,
+                        object_fit="contain",
+                        # height="200px",
+                        elem_classes="gallery",
+                        # scale=4
+                    )
+
+            with gr.Accordion("Advanced Settings", open=False):
+                seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
+                timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
+
+            submit_btn.click(
+                fn=process_inputs,
+                inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
+                outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
+            )
+
+            prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
+
+            # Automatically process the first example on launch
+            demo.load(
+                process_inputs,
+                inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
+                outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
+            )
+
+        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
+
+            with gr.Row(scale=8):
+                gr.HTML("<div></div>")
+
+            with gr.Row(scale=4, elem_classes="svg-container"):
+                concept_attention_callout_svg = gr.HTML(
+                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
+                    # container=False,
                     )

-            with gr.Accordion("Advanced Settings", open=False):
-                seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
-                layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
-                timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
-
-            submit_btn.click(
-                fn=process_inputs,
-                inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
-                outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
-            )
-
-            prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
-
-            # Automatically process the first example on launch
-            demo.load(
-                process_inputs,
-                inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
-                outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
-            )
-
-        with gr.Column(scale=4, min_width=250, elem_classes="svg-container"):
-            concept_attention_callout_svg = gr.HTML(
-                "<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
-                # container=False,
-            )
+            with gr.Row(scale=4):
+                gr.HTML("<div></div>")

 if __name__ == "__main__":
     if os.path.exists("/data-nvme/zerogpu-offload"):
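In the new layout the prompt dropdown still drives the concept list through update_default_concepts, which returns gr.update(...) so the existing multiselect is repopulated in place instead of being rebuilt, while the Run button and demo.load both call process_inputs. A self-contained sketch of just the prompt-to-concepts wiring (example values only; process_inputs and the Flux pipeline are omitted):

import gradio as gr

# Example prompt -> concept defaults, mirroring the dictionary in app.py.
DEFAULT_CONCEPTS = {
    "A dog by a tree": ["dog", "grass", "tree", "background"],
    "A man on the beach": ["man", "dirt", "ocean", "sky"],
    "A hot air balloon": ["balloon", "sky", "water", "tree"],
}

def update_default_concepts(prompt):
    # gr.update() patches the value of an existing component without recreating it
    return gr.update(value=DEFAULT_CONCEPTS.get(prompt, []))

with gr.Blocks() as demo:
    prompt = gr.Dropdown(list(DEFAULT_CONCEPTS), allow_custom_value=True, label="Prompt")
    concepts = gr.Dropdown(
        ["dog", "grass", "tree", "man", "dirt", "ocean", "sky", "balloon", "water", "background"],
        multiselect=True,
        allow_custom_value=True,
        max_choices=5,
        label="Concepts",
    )
    prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])

if __name__ == "__main__":
    demo.launch()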
 