elismasilva committed on
Commit 6138235 · 1 Parent(s): 14638d9

switch model

Files changed (2)
  1. app.py +44 -55
  2. pipeline/util.py +29 -1
app.py CHANGED
@@ -1,49 +1,59 @@
 import torch
 import spaces
-from diffusers import ControlNetUnionModel, AutoencoderKL
+from diffusers import ControlNetUnionModel, AutoencoderKL, UNet2DConditionModel
 import gradio as gr
 
 from pipeline.mod_controlnet_tile_sr_sdxl import StableDiffusionXLControlNetTileSRPipeline, calculate_overlap
 from pipeline.util import (
     SAMPLERS,
     create_hdr_effect,
+    optionally_disable_offloading,
     progressive_upscale,
+    quantize_8bit,
     select_scheduler,
-    torch_gc,
+    torch_gc,
 )
 
 device = "cuda"
 pipe = None
 last_loaded_model = None
-
-# Initialize the models and pipeline
-controlnet = ControlNetUnionModel.from_pretrained(
-    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
-).to(device=device)
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
+MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
+          "RealVisXL 5": "SG161222/RealVisXL_V5.0"
+          }
 
 def load_model(model_id):
     global pipe, last_loaded_model
-
+
     if model_id != last_loaded_model:
-        pipe = None
-
+
+        # Initialize the models and pipeline
+        controlnet = ControlNetUnionModel.from_pretrained(
+            "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+        )
+        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+        if pipe is not None:
+            optionally_disable_offloading(pipe)
+            torch_gc()
         pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
-            model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
-        ).to(device)
-
-        #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
+            MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+        )
+        pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
         pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
         pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
+
+        unet = UNet2DConditionModel.from_pretrained(MODELS[model_id], subfolder="unet", variant="fp16", use_safetensors=True)
+        quantize_8bit(unet) # << Enable this if you have limited VRAM
+        pipe.unet = unet
+
         last_loaded_model = model_id
 
-load_model("SG161222/RealVisXL_V5.0_Lightning")
+load_model("RealVisXL 5 Lightning")
 
 # region functions
 @spaces.GPU(duration=120)
 def predict(
-    model_id,
     image,
+    model_id,
     prompt,
     negative_prompt,
     resolution,
@@ -124,38 +134,6 @@ def set_maximum_resolution(max_tile_size, current_value):
 def select_tile_weighting_method(tile_weighting_method):
     return gr.update(visible=True if tile_weighting_method=="Gaussian" else False)
 
-@spaces.GPU(duration=120)
-def run_for_examples(image,
-                     prompt,
-                     negative_prompt,
-                     resolution,
-                     hdr,
-                     num_inference_steps,
-                     denoising_strenght,
-                     controlnet_strength,
-                     tile_gaussian_sigma,
-                     scheduler,
-                     guidance_scale,
-                     max_tile_size,
-                     tile_weighting_method):
-
-    predict(
-        model.value,
-        image,
-        prompt,
-        negative_prompt,
-        resolution,
-        hdr,
-        num_inference_steps,
-        denoising_strenght,
-        controlnet_strength,
-        tile_gaussian_sigma,
-        scheduler,
-        guidance_scale,
-        max_tile_size,
-        tile_weighting_method)
-
-
 # endregion
 
 css = """
@@ -174,7 +152,7 @@ body {
     text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
 }
 .fillable {
-    width: 95% !important;
+    width: 100% !important;
     max-width: unset !important;
 }
 #examples_container {
@@ -279,7 +257,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
         with gr.Row(elem_id="parameters_row"):
             gr.Markdown("### General parameters")
             model = gr.Dropdown(
-                label="Model", choices=["SG161222/RealVisXL_V5.0_Lightning", "SG161222/RealVisXL_V5.0"], value="SG161222/RealVisXL_V5.0_Lightning"
+                label="Model", choices=MODELS.keys(), value=list(MODELS.keys())[0]
             )
             tile_weighting_method = gr.Dropdown(
                 label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
@@ -303,9 +281,10 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
         with gr.Accordion(label="Example Images", open=True):
             with gr.Row(elem_id="examples_row"):
                 with gr.Column(scale=12, elem_id="examples_container"):
-                    gr.Examples(
+                    eg = gr.Examples(
                         examples=[
                             [ "./examples/1.jpg",
+                              "RealVisXL 5 Lightning",
                              prompt.value,
                              negative_prompt.value,
                              4096,
@@ -320,6 +299,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Cosine"
                             ],
                             [ "./examples/1.jpg",
+                              "RealVisXL 5",
                              prompt.value,
                              negative_prompt.value,
                              4096,
@@ -334,6 +314,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Cosine"
                             ],
                             [ "./examples/2.jpg",
+                              "RealVisXL 5 Lightning",
                              prompt.value,
                              negative_prompt.value,
                              4096,
@@ -348,6 +329,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Cosine"
                             ],
                             [ "./examples/2.jpg",
+                              "RealVisXL 5",
                              prompt.value,
                              negative_prompt.value,
                              4096,
@@ -362,6 +344,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Cosine"
                             ],
                             [ "./examples/3.jpg",
+                              "RealVisXL 5 Lightning",
                              prompt.value,
                              negative_prompt.value,
                              5120,
@@ -376,6 +359,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Gaussian"
                             ],
                             [ "./examples/3.jpg",
+                              "RealVisXL 5",
                              prompt.value,
                              negative_prompt.value,
                              5120,
@@ -390,6 +374,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Gaussian"
                             ],
                             [ "./examples/4.jpg",
+                              "RealVisXL 5 Lightning",
                              prompt.value,
                              negative_prompt.value,
                              8192,
@@ -404,6 +389,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Gaussian"
                             ],
                             [ "./examples/4.jpg",
+                              "RealVisXL 5",
                              prompt.value,
                              negative_prompt.value,
                              8192,
@@ -418,6 +404,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Gaussian"
                             ],
                             [ "./examples/5.jpg",
+                              "RealVisXL 5 Lightning",
                              prompt.value,
                              negative_prompt.value,
                              8192,
@@ -432,6 +419,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                              "Cosine"
                             ],
                             [ "./examples/5.jpg",
+                              "RealVisXL 5",
                              prompt.value,
                              negative_prompt.value,
                              8192,
@@ -448,6 +436,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                         ],
                         inputs=[
                             input_image,
+                            model,
                             prompt,
                             negative_prompt,
                             resolution,
@@ -461,13 +450,13 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                             max_tile_size,
                             tile_weighting_method,
                         ],
-                        fn=run_for_examples,
+                        fn=predict,
                         outputs=result,
                         cache_examples=False,
                     )
 
     max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
-    tile_weighting_method.select(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
+    tile_weighting_method.change(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
     generate_button.click(
         fn=clear_result,
         inputs=None,
@@ -475,8 +464,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
     ).then(
         fn=predict,
         inputs=[
-            model,
            input_image,
+            model,
            prompt,
            negative_prompt,
            resolution,
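
Note: app.py now imports quantize_8bit from pipeline.util and applies it to the reloaded UNet, but the helper's body is not part of this diff. A minimal sketch of such a helper, assuming the optional optimum-quanto package is available (the actual pipeline/util.py implementation may differ):

from optimum.quanto import freeze, qint8, quantize

def quantize_8bit(model):
    # Hypothetical sketch: quantize a module's weights to int8 in place to cut VRAM usage.
    if model is None:
        return
    quantize(model, weights=qint8)  # replace fp16 weights with int8 quantized tensors
    freeze(model)                   # materialize the quantized weights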
pipeline/util.py CHANGED
@@ -16,6 +16,8 @@
 import gc
 import cv2
 import numpy as np
+from torch import nn
+from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 import torch
 from PIL import Image
 
@@ -96,6 +98,32 @@ def select_scheduler(pipe, selected_sampler):
 
     return scheduler.from_config(config, **add_kwargs)
 
+def optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+    if _pipeline is not None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+
+
+                remove_hook_from_module(component, recurse=True)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
 
 # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
 def progressive_upscale(input_image, target_resolution, steps=3):
@@ -185,7 +213,7 @@ def torch_gc():
     if torch.cuda.is_available():
         with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+            #torch.cuda.ipc_collect()
 
     gc.collect()
 
219