elismasilva commited on
Commit
76bf5ed
·
1 Parent(s): e6c446b

added load model pipeline

Browse files
Files changed (2) hide show
  1. app.py +52 -28
  2. pipeline/util.py +1 -31
app.py CHANGED
@@ -8,34 +8,51 @@ from pipeline.util import (
8
  SAMPLERS,
9
  create_hdr_effect,
10
  progressive_upscale,
11
- select_scheduler
 
12
  )
13
 
14
  device = "cuda"
15
- pipe = None
16
- last_loaded_model = None
17
  MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
18
  "RealVisXL 5": "SG161222/RealVisXL_V5.0"
19
  }
 
 
 
 
 
 
20
 
21
- def load_model(model_id):
22
- global pipe, last_loaded_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- if model_id != last_loaded_model:
25
- # Initialize the models and pipeline
26
- controlnet = ControlNetUnionModel.from_pretrained(
27
- "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
28
- ).to(device)
29
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
30
- pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
31
- MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
32
- ).to(device)
33
- #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
34
- pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
35
- pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
36
- last_loaded_model = model_id
37
 
38
- load_model("RealVisXL 5 Lightning")
 
 
 
 
 
 
 
39
 
40
  # region functions
41
  @spaces.GPU(duration=120)
@@ -56,14 +73,12 @@ def predict(
56
  tile_weighting_method,
57
  progress=gr.Progress(track_tqdm=True),
58
  ):
59
- global pipe
60
-
61
  # Load model if changed
62
  load_model(model_id)
63
-
64
  # Set selected scheduler
65
  print(f"Using scheduler: {scheduler}...")
66
- pipe.scheduler = select_scheduler(pipe, scheduler)
67
 
68
  # Get current image size
69
  original_height = image.height
@@ -86,7 +101,7 @@ def predict(
86
 
87
  # Image generation
88
  print("Diffusion kicking in... almost done, coffee's on you!")
89
- image = pipe(
90
  image=image,
91
  control_image=control_image,
92
  control_mode=[6],
@@ -112,6 +127,14 @@ def predict(
112
  def clear_result():
113
  return gr.update(value=None)
114
 
 
 
 
 
 
 
 
 
115
  def set_maximum_resolution(max_tile_size, current_value):
116
  max_scale = 8 # <- you can try increase it to 12x, 16x if you wish!
117
  maximum_value = max_tile_size * max_scale
@@ -213,7 +236,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
213
  with gr.Column(scale=3):
214
  with gr.Row():
215
  with gr.Column():
216
- input_image = gr.Image(type="pil", label="Input Image",sources=["upload"], height=500)
217
  with gr.Column():
218
  result = gr.Image(
219
  label="Generated Image", show_label=True, format="png", interactive=False, scale=1, height=500, min_width=670
@@ -245,7 +268,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
245
  with gr.Row(elem_id="parameters_row"):
246
  gr.Markdown("### General parameters")
247
  model = gr.Dropdown(
248
- label="Model", choices=MODELS.keys(), value=list(MODELS.keys())[0]
249
  )
250
  tile_weighting_method = gr.Dropdown(
251
  label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
@@ -446,6 +469,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
446
 
447
  max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
448
  tile_weighting_method.change(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
 
449
  generate_button.click(
450
  fn=clear_result,
451
  inputs=None,
@@ -468,8 +492,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
468
  max_tile_size,
469
  tile_weighting_method,
470
  ],
471
- outputs=result,
472
- show_progress="full"
473
  )
474
  gr.Markdown(about)
 
475
  app.launch(share=False)
 
8
  SAMPLERS,
9
  create_hdr_effect,
10
  progressive_upscale,
11
+ select_scheduler,
12
+ torch_gc,
13
  )
14
 
15
  device = "cuda"
 
 
16
  MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
17
  "RealVisXL 5": "SG161222/RealVisXL_V5.0"
18
  }
19
+ class Pipeline:
20
+ def __init__(self):
21
+ self.pipe = None
22
+ self.controlnet = None
23
+ self.vae = None
24
+ self.last_loaded_model = None
25
 
26
+ def load_model(self, model_id):
27
+ if model_id != self.last_loaded_model:
28
+ print(f"\n--- Loading model: {model_id} ---")
29
+ if self.pipe is not None:
30
+ self.pipe.to("cpu")
31
+ del self.pipe
32
+ self.pipe = None
33
+ del self.controlnet
34
+ self.controlnet = None
35
+ del self.vae
36
+ self.vae = None
37
+ torch_gc()
38
+
39
+ self.controlnet = ControlNetUnionModel.from_pretrained(
40
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
41
+ ).to(device=device)
42
+ self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
43
 
44
+ self.pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
45
+ MODELS[model_id], controlnet=self.controlnet, vae=self.vae, torch_dtype=torch.float16, variant="fp16"
46
+ ).to(device=device)
 
 
 
 
 
 
 
 
 
 
47
 
48
+ self.pipe.enable_model_cpu_offload()
49
+ self.pipe.enable_vae_tiling()
50
+ self.pipe.enable_vae_slicing()
51
+ self.last_loaded_model = model_id
52
+ print(f"Model {model_id} loaded.")
53
+
54
+ def __call__(self, *args, **kwargs):
55
+ return self.pipe(*args, **kwargs)
56
 
57
  # region functions
58
  @spaces.GPU(duration=120)
 
73
  tile_weighting_method,
74
  progress=gr.Progress(track_tqdm=True),
75
  ):
 
 
76
  # Load model if changed
77
  load_model(model_id)
78
+
79
  # Set selected scheduler
80
  print(f"Using scheduler: {scheduler}...")
81
+ pipeline.pipe.scheduler = select_scheduler(pipeline.pipe, scheduler)
82
 
83
  # Get current image size
84
  original_height = image.height
 
101
 
102
  # Image generation
103
  print("Diffusion kicking in... almost done, coffee's on you!")
104
+ image = pipeline(
105
  image=image,
106
  control_image=control_image,
107
  control_mode=[6],
 
127
  def clear_result():
128
  return gr.update(value=None)
129
 
130
+ def load_model(model_name, on_load=False):
131
+ global pipeline # Declare pipeline as global
132
+ if on_load and 'pipeline' not in globals(): # Prevent reload page
133
+ pipeline = Pipeline() # Create pipeline inside the function
134
+ pipeline.load_model(model_name) # Load the initial model
135
+ elif pipeline is not None and not on_load:
136
+ pipeline.load_model(model_name) # Switch model
137
+
138
  def set_maximum_resolution(max_tile_size, current_value):
139
  max_scale = 8 # <- you can try increase it to 12x, 16x if you wish!
140
  maximum_value = max_tile_size * max_scale
 
236
  with gr.Column(scale=3):
237
  with gr.Row():
238
  with gr.Column():
239
+ input_image = gr.Image(type="pil", label="Input Image", sources=["upload"], height=500)
240
  with gr.Column():
241
  result = gr.Image(
242
  label="Generated Image", show_label=True, format="png", interactive=False, scale=1, height=500, min_width=670
 
268
  with gr.Row(elem_id="parameters_row"):
269
  gr.Markdown("### General parameters")
270
  model = gr.Dropdown(
271
+ label="Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0]
272
  )
273
  tile_weighting_method = gr.Dropdown(
274
  label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
 
469
 
470
  max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
471
  tile_weighting_method.change(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
472
+
473
  generate_button.click(
474
  fn=clear_result,
475
  inputs=None,
 
492
  max_tile_size,
493
  tile_weighting_method,
494
  ],
495
+ outputs=result,
 
496
  )
497
  gr.Markdown(about)
498
+ app.load(fn=load_model, inputs=[model, gr.State(value=True)], outputs=None, concurrency_limit=1) # Load initial model on app load
499
  app.launch(share=False)
pipeline/util.py CHANGED
@@ -16,8 +16,6 @@
16
  import gc
17
  import cv2
18
  import numpy as np
19
- from torch import nn
20
- from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
21
  import torch
22
  from PIL import Image
23
 
@@ -98,32 +96,6 @@ def select_scheduler(pipe, selected_sampler):
98
 
99
  return scheduler.from_config(config, **add_kwargs)
100
 
101
- def optionally_disable_offloading(_pipeline):
102
- """
103
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
104
-
105
- Args:
106
- _pipeline (`DiffusionPipeline`):
107
- The pipeline to disable offloading for.
108
-
109
- Returns:
110
- tuple:
111
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
112
- """
113
- is_model_cpu_offload = False
114
- is_sequential_cpu_offload = False
115
- if _pipeline is not None:
116
- for _, component in _pipeline.components.items():
117
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
118
- if not is_model_cpu_offload:
119
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
120
- if not is_sequential_cpu_offload:
121
- is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
122
-
123
-
124
- remove_hook_from_module(component, recurse=True)
125
-
126
- return (is_model_cpu_offload, is_sequential_cpu_offload)
127
 
128
  # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
129
  def progressive_upscale(input_image, target_resolution, steps=3):
@@ -210,14 +182,12 @@ def create_hdr_effect(original_image, hdr):
210
 
211
 
212
  def torch_gc():
 
213
  if torch.cuda.is_available():
214
  with torch.cuda.device("cuda"):
215
  torch.cuda.empty_cache()
216
  torch.cuda.ipc_collect()
217
 
218
- gc.collect()
219
-
220
-
221
  def quantize_8bit(unet):
222
  if unet is None:
223
  return
 
16
  import gc
17
  import cv2
18
  import numpy as np
 
 
19
  import torch
20
  from PIL import Image
21
 
 
96
 
97
  return scheduler.from_config(config, **add_kwargs)
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
101
  def progressive_upscale(input_image, target_resolution, steps=3):
 
182
 
183
 
184
  def torch_gc():
185
+ gc.collect()
186
  if torch.cuda.is_available():
187
  with torch.cuda.device("cuda"):
188
  torch.cuda.empty_cache()
189
  torch.cuda.ipc_collect()
190
 
 
 
 
191
  def quantize_8bit(unet):
192
  if unet is None:
193
  return