Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9d38503
1
Parent(s):
6138235
switch model
Browse files- app.py +5 -14
- pipeline/util.py +1 -1
app.py
CHANGED
@@ -24,27 +24,18 @@ MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
|
24 |
def load_model(model_id):
|
25 |
global pipe, last_loaded_model
|
26 |
|
27 |
-
if model_id != last_loaded_model:
|
28 |
-
|
29 |
# Initialize the models and pipeline
|
30 |
controlnet = ControlNetUnionModel.from_pretrained(
|
31 |
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
32 |
-
)
|
33 |
-
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
34 |
-
if pipe is not None:
|
35 |
-
optionally_disable_offloading(pipe)
|
36 |
-
torch_gc()
|
37 |
pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
|
38 |
MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
39 |
-
)
|
40 |
-
pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
41 |
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
42 |
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
43 |
-
|
44 |
-
unet = UNet2DConditionModel.from_pretrained(MODELS[model_id], subfolder="unet", variant="fp16", use_safetensors=True)
|
45 |
-
quantize_8bit(unet) # << Enable this if you have limited VRAM
|
46 |
-
pipe.unet = unet
|
47 |
-
|
48 |
last_loaded_model = model_id
|
49 |
|
50 |
load_model("RealVisXL 5 Lightning")
|
|
|
24 |
def load_model(model_id):
|
25 |
global pipe, last_loaded_model
|
26 |
|
27 |
+
if model_id != last_loaded_model:
|
|
|
28 |
# Initialize the models and pipeline
|
29 |
controlnet = ControlNetUnionModel.from_pretrained(
|
30 |
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
31 |
+
).to(device)
|
32 |
+
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
|
|
|
|
|
|
|
33 |
pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
|
34 |
MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
35 |
+
).to(device)
|
36 |
+
#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
37 |
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
38 |
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
|
|
|
|
|
|
|
|
|
|
39 |
last_loaded_model = model_id
|
40 |
|
41 |
load_model("RealVisXL 5 Lightning")
|
pipeline/util.py
CHANGED
@@ -213,7 +213,7 @@ def torch_gc():
|
|
213 |
if torch.cuda.is_available():
|
214 |
with torch.cuda.device("cuda"):
|
215 |
torch.cuda.empty_cache()
|
216 |
-
|
217 |
|
218 |
gc.collect()
|
219 |
|
|
|
213 |
if torch.cuda.is_available():
|
214 |
with torch.cuda.device("cuda"):
|
215 |
torch.cuda.empty_cache()
|
216 |
+
torch.cuda.ipc_collect()
|
217 |
|
218 |
gc.collect()
|
219 |
|