elismasilva committed on
Commit
9d38503
·
1 Parent(s): 6138235

switch model

Browse files
Files changed (2) hide show
  1. app.py +5 -14
  2. pipeline/util.py +1 -1
app.py CHANGED
@@ -24,27 +24,18 @@ MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
24
  def load_model(model_id):
25
  global pipe, last_loaded_model
26
 
27
- if model_id != last_loaded_model:
28
-
29
  # Initialize the models and pipeline
30
  controlnet = ControlNetUnionModel.from_pretrained(
31
  "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
32
- )
33
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
34
- if pipe is not None:
35
- optionally_disable_offloading(pipe)
36
- torch_gc()
37
  pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
38
  MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
39
- )
40
- pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
41
  pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
42
  pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
43
-
44
- unet = UNet2DConditionModel.from_pretrained(MODELS[model_id], subfolder="unet", variant="fp16", use_safetensors=True)
45
- quantize_8bit(unet) # << Enable this if you have limited VRAM
46
- pipe.unet = unet
47
-
48
  last_loaded_model = model_id
49
 
50
  load_model("RealVisXL 5 Lightning")
 
24
  def load_model(model_id):
25
  global pipe, last_loaded_model
26
 
27
+ if model_id != last_loaded_model:
 
28
  # Initialize the models and pipeline
29
  controlnet = ControlNetUnionModel.from_pretrained(
30
  "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
31
+ ).to(device)
32
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
 
 
 
33
  pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
34
  MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
35
+ ).to(device)
36
+ #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
37
  pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
38
  pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
 
 
 
 
 
39
  last_loaded_model = model_id
40
 
41
  load_model("RealVisXL 5 Lightning")
pipeline/util.py CHANGED
@@ -213,7 +213,7 @@ def torch_gc():
213
  if torch.cuda.is_available():
214
  with torch.cuda.device("cuda"):
215
  torch.cuda.empty_cache()
216
- #torch.cuda.ipc_collect()
217
 
218
  gc.collect()
219
 
 
213
  if torch.cuda.is_available():
214
  with torch.cuda.device("cuda"):
215
  torch.cuda.empty_cache()
216
+ torch.cuda.ipc_collect()
217
 
218
  gc.collect()
219