File size: 3,913 Bytes
50f828f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36beff9
 
50f828f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""HiDiffusion demo for sd1.5 and sdxl."""
from functools import lru_cache

import gradio as gr
import PIL
import torch
from diffusers import DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline

from hidiffusion import apply_hidiffusion, remove_hidiffusion

# Hugging Face Hub model ids, keyed by the short names shown in the UI radio.
pretrained_models = {
    "sd1.5": "runwayml/stable-diffusion-v1-5",
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
}

# Diffusers pipeline class to instantiate for each model family; keys must
# stay in sync with ``pretrained_models``.
pipeline_types = {
    "sd1.5": StableDiffusionPipeline,
    "sdxl": StableDiffusionXLPipeline,
}

# Prefer the GPU when one is visible; every loaded pipeline is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@lru_cache
def load_pipeline(model_type: str) -> DiffusionPipeline:
    """Load, configure, and cache the pretrained pipeline for ``model_type``.

    The result is memoized by ``lru_cache`` (only two possible keys), so each
    model is downloaded and initialized at most once per process.

    Args:
        model_type: One of the keys of ``pretrained_models`` ("sd1.5", "sdxl").

    Returns:
        The pipeline moved to ``device`` with VAE tiling enabled, and xformers
        memory-efficient attention enabled when CUDA is available.

    Raises:
        KeyError: If ``model_type`` is not a known model key.
    """
    pretrained_model = pretrained_models[model_type]
    pipeline_cls = pipeline_types[model_type]
    scheduler = DDIMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
    use_cuda = torch.cuda.is_available()
    # fp16 weights are only usable on GPU: many CPU kernels are not implemented
    # for half precision, so fall back to full-precision weights without CUDA.
    if use_cuda:
        dtype_kwargs = {"torch_dtype": torch.float16, "variant": "fp16"}
    else:
        dtype_kwargs = {"torch_dtype": torch.float32}
    pipe = pipeline_cls.from_pretrained(pretrained_model, scheduler=scheduler, **dtype_kwargs).to(device)
    # Tiled VAE decoding keeps memory bounded at the large resolutions
    # HiDiffusion targets (up to 4096 in the UI sliders).
    pipe.enable_vae_tiling()
    if use_cuda:
        pipe.enable_xformers_memory_efficient_attention()
    return pipe


def generate(
    model_type: str, use_hidiffusion: bool, positive: str, negative: str, width: int, height: int, guidance_scale: float
) -> PIL.Image.Image:
    """Run one text-to-image generation for the Gradio interface.

    Args:
        model_type: Key selecting the pipeline ("sd1.5" or "sdxl").
        use_hidiffusion: Whether to patch the pipeline with HiDiffusion.
        positive: Positive text prompt.
        negative: Negative text prompt.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.

    Returns:
        The first generated image.
    """
    pipe = load_pipeline(model_type)
    print(f"{model_type} pipeline is loaded")
    if use_hidiffusion:
        apply_hidiffusion(pipe)
        print("hidiffusion is applied")
    else:
        # load_pipeline caches pipelines, so a previous request may have left
        # HiDiffusion patches applied to this shared object; strip them to
        # guarantee a vanilla run when the checkbox is off.
        remove_hidiffusion(pipe)
    # eta=1.0 matches the DDIM sampling configuration used by HiDiffusion.
    image = pipe(
        positive, negative_prompt=negative, guidance_scale=guidance_scale, height=height, width=width, eta=1.0
    ).images[0]
    print("generation is done")
    return image


# Gradio UI definition. The order of ``inputs`` must match the positional
# parameters of ``generate``; each ``examples`` row follows the same order.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Radio(choices=["sd1.5", "sdxl"], label="Model Type", value="sd1.5"),
        gr.Checkbox(value=True, label="Use HiDiffusion"),
        gr.Textbox(label="Positive Prompt"),
        gr.Textbox(label="Negative Prompt"),
        # Resolutions beyond the base 512/1024 are what HiDiffusion enables.
        gr.Slider(512, 4096, value=1024, step=1, label="width"),
        gr.Slider(512, 4096, value=1024, step=1, label="height"),
        gr.Slider(0.0, 20.0, value=7.5, step=0.1, label="Guidance Scale"),
    ],
    outputs=gr.Image(),
    allow_flagging="never",
    title="HiDiffusion Demo",
    description="""
        HiDiffusion is a training-free method that increases the resolution and speed of pretrained diffusion models.\n
        It is designed as a plug-and-play implementation. It can be integrated into diffusion pipelines by only adding a single line of code!\n
        More information: https://github.com/megvii-research/HiDiffusion
    """,
    # One pre-filled example per model family; values map positionally onto
    # the ``inputs`` list above.
    examples=[
        [
            "sd1.5",
            True,
            # positive
            "thick strokes, bright colors, an exotic fox, cute, chibi kawaii,"
            "detailed fur, hyperdetailed , big reflective eyes, fairytale, artstation,"
            "centered composition, perfect composition, centered, vibrant colors, muted colors, high detailed, 8k.",
            # negative
            "ugly, tiling, poorly drawn face, out of frame, disfigured, deformed, blurry, bad anatomy, blurred",
            # width
            1024,
            # height
            1024,
            # guidance scale
            7.5,
        ],
        [
            "sdxl",
            True,
            # positive
            "thick strokes, bright colors, an exotic fox, cute, chibi kawaii,"
            "detailed fur, hyperdetailed , big reflective eyes, fairytale, artstation,"
            "centered composition, perfect composition, centered, vibrant colors, muted colors, high detailed, 8k.",
            # negative
            "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
            # width
            2048,
            # height
            2048,
            # guidance scale
            7.5,
        ],
    ],
)


if __name__ == "__main__":
    # Bind to all interfaces so the demo is reachable from outside a
    # container / remote host, not just localhost.
    demo.launch(server_name="0.0.0.0")