elismasilva committed on
Commit
32e89fe
·
1 Parent(s): 4bd749c

Initial commit

.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
1
+ # These are supported funding model platforms
2
+
3
+ ko_fi: elismasilva
.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ mod_tests/
4
+ /.vs
5
+ .vscode/
6
+ .idea/
7
+ venv/
8
+ .venv/
9
+ *.log
10
+ .DS_Store
11
+ .gradio
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: Mod Control Tile Upscaler Sdxl
3
- emoji: 👀
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.19.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: MoD ControlNet Tile Upscaler for SDXL
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Mod ControlNet Tile Upscaler SDXL
3
+ emoji: 🚀
4
  colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.15.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Mixture of Diffusers and ControlNet Tile Upscaler for SDXL
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,389 @@
1
+ import torch
2
+ import spaces
3
+ from diffusers import ControlNetUnionModel, AutoencoderKL, UNet2DConditionModel
4
+ import gradio as gr
5
+
6
+ from pipeline.mod_controlnet_tile_sr_sdxl import StableDiffusionXLControlNetTileSRPipeline
7
+ from pipeline.util import (
8
+ SAMPLERS,
9
+ Platinum,
10
+ calculate_overlap,
11
+ create_hdr_effect,
12
+ progressive_upscale,
13
+ quantize_8bit,
14
+ select_scheduler,
15
+ )
16
+
17
+ device = "cuda"
18
+
19
+ # Initialize the models and pipeline
20
+ controlnet = ControlNetUnionModel.from_pretrained(
21
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
22
+ ).to(device=device)
23
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
24
+
25
+ model_id = "SG161222/RealVisXL_V5.0"
26
+ pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
27
+ model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
28
+ ).to(device)
29
+
30
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
31
+ quantize_8bit(unet) # << Enable this if you have limited VRAM
32
+ pipe.unet = unet
33
+
34
+ pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
35
+ pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
36
+ pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
37
+
38
+ # region functions
39
+ @spaces.GPU
40
+ def predict(
41
+ image,
42
+ prompt,
43
+ negative_prompt,
44
+ resolution,
45
+ hdr,
46
+ num_inference_steps,
47
+ denoising_strength,
48
+ controlnet_strength,
49
+ tile_gaussian_sigma,
50
+ scheduler,
51
+ guidance_scale,
52
+ max_tile_size,
53
+ tile_weighting_method,
54
+ progress=gr.Progress(track_tqdm=True),
55
+ ):
56
+ global pipe
57
+
58
+ # Set selected scheduler
59
+ print(f"Using scheduler: {scheduler}...")
60
+ pipe.scheduler = select_scheduler(pipe, scheduler)
61
+
62
+ # Get current image size
63
+ original_height = image.height
64
+ original_width = image.width
65
+ print(f"Current resolution: H:{original_height} x W:{original_width}")
66
+
67
+ # Pre-upscale image for tiling
68
+ control_image = progressive_upscale(image, resolution)
69
+ control_image = create_hdr_effect(control_image, hdr)
70
+
71
+ # Update target height and width
72
+ target_height = control_image.height
73
+ target_width = control_image.width
74
+ print(f"Target resolution: H:{target_height} x W:{target_width}")
75
+ print(f"Applied HDR effect: {hdr > 0}")
76
+
77
+ # Calculate overlap size
78
+ normal_tile_overlap, border_tile_overlap = calculate_overlap(target_width, target_height)
79
+
80
+ # Image generation
81
+ print("Diffusion kicking in... almost done, coffee's on you!")
82
+ image = pipe(
83
+ image=control_image,
84
+ control_image=image,
85
+ control_mode=[6],
86
+ controlnet_conditioning_scale=float(controlnet_strength),
87
+ prompt=prompt,
88
+ negative_prompt=negative_prompt,
89
+ normal_tile_overlap=normal_tile_overlap,
90
+ border_tile_overlap=border_tile_overlap,
91
+ height=target_height,
92
+ width=target_width,
93
+ original_size=(original_width, original_height),
94
+ target_size=(target_width, target_height),
95
+ guidance_scale=guidance_scale,
96
+ strength=float(denoising_strength),
97
+ tile_weighting_method=tile_weighting_method,
98
+ max_tile_size=max_tile_size,
99
+ tile_gaussian_sigma=float(tile_gaussian_sigma),
100
+ num_inference_steps=num_inference_steps,
101
+ )["images"][0]
102
+ image.save("result.png")
103
+ return image
104
+
105
+
106
+ def clear_result():
107
+ return gr.update(value=None)
108
+
109
+ def set_maximum_resolution(max_tile_size, current_value):
110
+ max_scale = 8 # <- you can try increasing this to 12x or 16x if you wish
111
+ maximum_value = max_tile_size * max_scale
112
+ if current_value > maximum_value:
113
+ return gr.update(maximum=maximum_value, value=maximum_value)
114
+ return gr.update(maximum=maximum_value)
115
+
116
+ def select_tile_weighting_method(tile_weighting_method):
117
+ return gr.update(visible=tile_weighting_method == "Gaussian")
118
+
119
+ # endregion
120
+
121
+ css = """
122
+ body {
123
+ background: linear-gradient(135deg, #667eea, #764ba2);
124
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
125
+ color: #333;
126
+ margin: 0;
127
+ padding: 0;
128
+ }
129
+ .gradio-container {
130
+ background: rgba(255, 255, 255, 0.95);
131
+ border-radius: 15px;
132
+ padding: 30px 40px;
133
+ box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
134
+ margin: 40px 340px;
135
+ /*max-width: 1200px;*/
136
+ }
137
+ .gradio-container h1 {
138
+ color: #333;
139
+ text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
140
+ }
141
+ .fillable {
142
+ width: 95% !important;
143
+ max-width: unset !important;
144
+ }
145
+ #examples_container {
146
+ margin: auto;
147
+ width: 90%;
148
+ }
149
+ #examples_row {
150
+ justify-content: center;
151
+ }
152
+ #tips_row{
153
+ padding-left: 20px;
154
+ }
155
+ .sidebar {
156
+ background: rgba(255, 255, 255, 0.98);
157
+ border-radius: 10px;
158
+ padding: 10px;
159
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
160
+ }
161
+ .sidebar .toggle-button {
162
+ background: linear-gradient(90deg, #7367f0, #9c93f4);
163
+ border: none;
164
+ color: #fff;
165
+ padding: 12px 24px;
166
+ text-transform: uppercase;
167
+ font-weight: bold;
168
+ letter-spacing: 1px;
169
+ border-radius: 5px;
170
+ cursor: pointer;
171
+ transition: transform 0.2s ease-in-out;
172
+ }
173
+ .toggle-button:hover {
174
+ transform: scale(1.05);
175
+ }
176
+ """
177
+ title = """<h1 align="center">MoD ControlNet Tile Upscaler for SDXL🤗</h1>
178
+ <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center; overflow:hidden;">
179
+ <span>This project implements the <a href="https://arxiv.org/pdf/2408.06072">📜 MoD (Mixture-of-Diffusers)</a> tiled diffusion technique and combines it with SDXL's ControlNet Tile process.</span>
180
+ <span>💻 <b><a href="https://github.com/DEVAIEXP/mod-control-tile-upscaler-sdxl">GitHub Code</a></b></span>
181
+ <span>🚀 <b>ControlNet Union Power!</b> Check out the model: <a href="https://huggingface.co/xinsir/controlnet-union-sdxl-1.0">ControlNet Union</a></span>
182
+ <span>🎨 <b>RealVisXL V5.0 for Stunning Visuals!</b> Explore it here: <a href="https://huggingface.co/SG161222/RealVisXL_V5.0">RealVisXL</a></span>
183
+ </div>
184
+ """
185
+
186
+ tips = """
187
+ ### Method
188
+ This project proposes an enhanced image upscaling method that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the denoising process within the latent space.
189
+
190
+ Let's compare our method with conventional ControlNet Tile upscaling:
191
+
192
+ **Conventional ControlNet Tile:**
193
+ * Processes tiles in pixel space, potentially leading to edge artifacts during fusion.
194
+ * Processes each tile sequentially, increasing overall execution time (e.g., 16 tiles x 3 min = 48 min).
195
+ * Pixel space fusion using masks (e.g., Gaussian) can result in visible seams.
196
+ * Fixed or adaptively sized tiles and overlap can vary, causing inconsistencies.
197
+
198
+ **Proposed Method (MoD ControlNet Tile Upscaler):**
199
+ * Processes tiles in latent space, enabling smoother fusion and mitigating edge artifacts.
200
+ * Processes all tiles in parallel during denoising, drastically reducing execution time.
201
+ * Latent space fusion with dynamically calculated weights ensures seamless transitions between tiles.
202
+ * Tile size and overlap are dynamically adjusted based on the upscaling factor. For factors below 4x, a fixed overlap maintains consistency.
203
+
204
+ """
205
+
206
+ about = """
207
+ 📧 **Contact**
208
+ <br>
209
+ If you have any questions or suggestions, feel free to send them to <b>[email protected]</b>.
210
+ """
211
+
212
+ with gr.Blocks(css=css, theme=Platinum(), title="MoD ControlNet Tile Upscaler") as app:
213
+ gr.Markdown(title)
214
+ with gr.Row():
215
+ with gr.Column(scale=3):
216
+ with gr.Row():
217
+ with gr.Column():
218
+ input_image = gr.Image(type="pil", label="Input Image",sources=["upload"], height=500)
219
+ with gr.Column():
220
+ result = gr.Image(
221
+ label="Generated Image", show_label=True, format="png", interactive=False, scale=1, height=500, min_width=670
222
+ )
223
+ with gr.Row():
224
+ with gr.Accordion("Input Prompt", open=False):
225
+ with gr.Column():
226
+ prompt = gr.Textbox(
227
+ lines=2,
228
+ label="Prompt",
229
+ placeholder="Default prompt for image",
230
+ value="high-quality, noise-free edges, high quality, 4k, hd, 8k",
231
+ )
232
+ with gr.Column():
233
+ negative_prompt = gr.Textbox(
234
+ lines=2,
235
+ label="Negative Prompt (Optional)",
236
+ placeholder="e.g., blurry, low resolution, artifacts, poor details",
237
+ value="blurry, pixelated, noisy, low resolution, artifacts, poor details",
238
+ )
239
+ with gr.Row():
240
+ generate_button = gr.Button("Generate", variant="primary")
241
+ with gr.Column(scale=1):
242
+ with gr.Row(elem_id="tips_row"):
243
+ gr.Markdown(tips)
244
+ with gr.Sidebar(label="Parameters", open=True):
245
+ with gr.Row(elem_id="parameters_row"):
246
+ gr.Markdown("### General parameters")
247
+ tile_weighting_method = gr.Dropdown(
248
+ label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
249
+ )
250
+ tile_gaussian_sigma = gr.Slider(label="Gaussian Sigma", minimum=0.05, maximum=1.0, step=0.01, value=0.3, visible=False)
251
+ max_tile_size = gr.Dropdown(label="Max. Tile Size", choices=[1024, 1280], value=1024)
252
+ resolution = gr.Slider(minimum=128, maximum=8192, value=2048, step=128, label="Resolution")
253
+ num_inference_steps = gr.Slider(minimum=2, maximum=100, value=30, step=1, label="Inference Steps")
254
+ guidance_scale = gr.Slider(minimum=1, maximum=20, value=6, step=0.1, label="Guidance Scale")
255
+ denoising_strength = gr.Slider(minimum=0.1, maximum=1, value=0.6, step=0.01, label="Denoising Strength")
256
+ controlnet_strength = gr.Slider(
257
+ minimum=0.1, maximum=2.0, value=1.0, step=0.05, label="ControlNet Strength"
258
+ )
259
+ hdr = gr.Slider(minimum=0, maximum=1, value=0, step=0.1, label="HDR Effect")
260
+ with gr.Row():
261
+ scheduler = gr.Dropdown(
262
+ label="Sampler",
263
+ choices=list(SAMPLERS.keys()),
264
+ value="UniPC",
265
+ )
266
+ with gr.Accordion(label="Example Images", open=True):
267
+ with gr.Row(elem_id="examples_row"):
268
+ with gr.Column(scale=12, elem_id="examples_container"):
269
+ gr.Examples(
270
+ examples=[
271
+ [ "./examples/1.jpg",
272
+ prompt.value,
273
+ negative_prompt.value,
274
+ 4096,
275
+ 0.0,
276
+ 35,
277
+ 0.65,
278
+ 1.0,
279
+ 0.3,
280
+ "UniPC",
281
+ 4,
282
+ 1024,
283
+ "Cosine"
284
+ ],
285
+ [ "./examples/2.jpg",
286
+ prompt.value,
287
+ negative_prompt.value,
288
+ 4096,
289
+ 0.5,
290
+ 35,
291
+ 0.65,
292
+ 1.0,
293
+ 0.3,
294
+ "UniPC",
295
+ 4,
296
+ 1024,
297
+ "Cosine"
298
+ ],
299
+ [ "./examples/3.jpg",
300
+ prompt.value,
301
+ negative_prompt.value,
302
+ 5120,
303
+ 0.5,
304
+ 50,
305
+ 0.65,
306
+ 1.0,
307
+ 0.3,
308
+ "UniPC",
309
+ 4,
310
+ 1280,
311
+ "Gaussian"
312
+ ],
313
+ [ "./examples/4.jpg",
314
+ prompt.value,
315
+ negative_prompt.value,
316
+ 8192,
317
+ 0.1,
318
+ 50,
319
+ 0.5,
320
+ 1.0,
321
+ 0.3,
322
+ "UniPC",
323
+ 4,
324
+ 1024,
325
+ "Gaussian"
326
+ ],
327
+ [ "./examples/5.jpg",
328
+ prompt.value,
329
+ negative_prompt.value,
330
+ 8192,
331
+ 0.3,
332
+ 50,
333
+ 0.5,
334
+ 1.0,
335
+ 0.3,
336
+ "UniPC",
337
+ 4,
338
+ 1024,
339
+ "Cosine"
340
+ ],
341
+ ],
342
+ inputs=[
343
+ input_image,
344
+ prompt,
345
+ negative_prompt,
346
+ resolution,
347
+ hdr,
348
+ num_inference_steps,
349
+ denoising_strength,
350
+ controlnet_strength,
351
+ tile_gaussian_sigma,
352
+ scheduler,
353
+ guidance_scale,
354
+ max_tile_size,
355
+ tile_weighting_method,
356
+ ],
357
+ fn=predict,
358
+ outputs=result,
359
+ cache_examples=False,
360
+ )
361
+
362
+ max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
363
+ tile_weighting_method.select(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
364
+ generate_button.click(
365
+ fn=clear_result,
366
+ inputs=None,
367
+ outputs=result,
368
+ ).then(
369
+ fn=predict,
370
+ inputs=[
371
+ input_image,
372
+ prompt,
373
+ negative_prompt,
374
+ resolution,
375
+ hdr,
376
+ num_inference_steps,
377
+ denoising_strength,
378
+ controlnet_strength,
379
+ tile_gaussian_sigma,
380
+ scheduler,
381
+ guidance_scale,
382
+ max_tile_size,
383
+ tile_weighting_method,
384
+ ],
385
+ outputs=result,
386
+ show_progress="full"
387
+ )
388
+ gr.Markdown(about)
389
+ app.launch(share=False)
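Note on the tile weighting options in `app.py`: the `tips` text says tiles are fused with per-tile weights, and the UI offers "Cosine" and "Gaussian" methods plus a Gaussian sigma. The sketch below is not the pipeline's implementation (that code is not shown in this part of the commit). It is a minimal 1-D illustration of weighted tile blending, assuming a raised-cosine window and a Gaussian window whose sigma is relative to the tile size. The helper names `tile_weights` and `blend_tiles` are hypothetical.

```python
import math


def tile_weights(size, method="Cosine", sigma=0.3):
    """Return 1-D blending weights for a tile with `size` samples (hypothetical helper)."""
    centres = [(i + 0.5) / size for i in range(size)]  # sample positions in (0, 1)
    if method == "Cosine":
        # Raised cosine: close to 0 at the tile borders, peaking at the centre.
        return [0.5 * (1.0 - math.cos(2.0 * math.pi * x)) for x in centres]
    if method == "Gaussian":
        # Bell curve centred on the tile; sigma is assumed relative to the tile size.
        return [math.exp(-((x - 0.5) ** 2) / (2.0 * sigma**2)) for x in centres]
    raise ValueError(f"unknown weighting method: {method}")


def blend_tiles(tiles, starts, total, method="Cosine", sigma=0.3):
    """Weighted average of overlapping 1-D tiles placed at offsets `starts`.

    Assumes every output position is covered by at least one tile.
    """
    acc = [0.0] * total   # weighted sum of tile values
    norm = [0.0] * total  # sum of weights, used to normalize
    for tile, start in zip(tiles, starts):
        weights = tile_weights(len(tile), method, sigma)
        for offset, (value, w) in enumerate(zip(tile, weights)):
            acc[start + offset] += w * value
            norm[start + offset] += w
    return [a / n for a, n in zip(acc, norm)]


# Two constant tiles of length 8 that overlap by 4 samples.
left = [1.0] * 8
right = [3.0] * 8
blended = blend_tiles([left, right], starts=[0, 4], total=12)
print([round(v, 3) for v in blended])
```

In the overlap the output moves gradually from the left tile's value to the right tile's value, because each tile's weight falls off toward its border. The real pipeline works on 2-D latents at every denoising step, so treat this only as an intuition aid.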
examples/1.jpg ADDED
examples/2.jpg ADDED
examples/3.jpg ADDED
examples/4.jpg ADDED
examples/5.jpg ADDED
pipeline/mod_controlnet_tile_sr_sdxl.py ADDED
@@ -0,0 +1,1845 @@
1
+ # Copyright 2025 DEVAIEXP and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from enum import Enum
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+ from transformers import (
24
+ CLIPTextModel,
25
+ CLIPTextModelWithProjection,
26
+ CLIPTokenizer,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import (
31
+ FromSingleFileMixin,
32
+ StableDiffusionXLLoraLoaderMixin,
33
+ TextualInversionLoaderMixin,
34
+ )
35
+ from diffusers.models import (
36
+ AutoencoderKL,
37
+ ControlNetModel,
38
+ ControlNetUnionModel,
39
+ MultiControlNetModel,
40
+ UNet2DConditionModel,
41
+ )
42
+ from diffusers.models.attention_processor import (
43
+ AttnProcessor2_0,
44
+ XFormersAttnProcessor,
45
+ )
46
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
48
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
49
+ from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
50
+ from diffusers.utils import (
51
+ USE_PEFT_BACKEND,
52
+ logging,
53
+ replace_example_docstring,
54
+ scale_lora_layers,
55
+ unscale_lora_layers,
56
+ )
57
+ from diffusers.utils.import_utils import is_invisible_watermark_available
58
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
59
+
60
+ if is_invisible_watermark_available():
61
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
62
+
63
+ from diffusers.utils import is_torch_xla_available
64
+
65
+ if is_torch_xla_available():
66
+ import torch_xla.core.xla_model as xm
67
+
68
+ XLA_AVAILABLE = True
69
+ else:
70
+ XLA_AVAILABLE = False
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+
75
+ EXAMPLE_DOC_STRING = """
76
+ Examples:
77
+ ```py
78
+ # !pip install controlnet_aux
79
+ from diffusers import (
80
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
81
+ ControlNetUnionModel,
82
+ AutoencoderKL,
83
+ )
84
+ from diffusers.utils import load_image
85
+ import torch
86
+ from PIL import Image
87
+ import numpy as np
88
+
89
+ prompt = "A cat"
90
+ # download an image
91
+ image = load_image(
92
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
93
+ )
94
+ # initialize the models and pipeline
95
+ controlnet = ControlNetUnionModel.from_pretrained(
96
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
97
+ )
98
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
99
+ pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
100
+ "stabilityai/stable-diffusion-xl-base-1.0",
101
+ controlnet=controlnet,
102
+ vae=vae,
103
+ torch_dtype=torch.float16,
104
+ variant="fp16",
105
+ ).to("cuda")
106
+ # `enable_model_cpu_offload` is not recommended due to multiple generations
107
+ height = image.height
108
+ width = image.width
109
+ ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
110
+ # A 3 x 3 upscale corresponds to a multiple of 16 * 3, a 2 x 2 upscale to a multiple of 16 * 2, and so on.
111
+ scale_image_factor = 3
112
+ base_factor = 16
113
+ factor = scale_image_factor * base_factor
114
+ W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
115
+ image = image.resize((W, H))
116
+ target_width = W // scale_image_factor
117
+ target_height = H // scale_image_factor
118
+ images = []
119
+ crops_coords_list = [
120
+ (0, 0),
121
+ (0, width // 2),
122
+ (height // 2, 0),
123
+ (width // 2, height // 2),
124
+ 0,
125
+ 0,
126
+ 0,
127
+ 0,
128
+ 0,
129
+ ]
130
+ for i in range(scale_image_factor):
131
+ for j in range(scale_image_factor):
132
+ left = j * target_width
133
+ top = i * target_height
134
+ right = left + target_width
135
+ bottom = top + target_height
136
+ cropped_image = image.crop((left, top, right, bottom))
137
+ cropped_image = cropped_image.resize((W, H))
138
+ images.append(cropped_image)
139
+ # set ControlNetUnion input
140
+ result_images = []
141
+ for sub_img, crops_coords in zip(images, crops_coords_list):
142
+ new_width, new_height = W, H
143
+ out = pipe(
144
+ prompt=[prompt] * 1,
145
+ image=sub_img,
146
+ control_image=[sub_img],
147
+ control_mode=[6],
148
+ width=new_width,
149
+ height=new_height,
150
+ num_inference_steps=30,
151
+ crops_coords_top_left=(W, H),
152
+ target_size=(W, H),
153
+ original_size=(W * 2, H * 2),
154
+ )
155
+ result_images.append(out.images[0])
156
+ new_im = Image.new("RGB", (new_width * scale_image_factor, new_height * scale_image_factor))
157
+ new_im.paste(result_images[0], (0, 0))
158
+ new_im.paste(result_images[1], (new_width, 0))
159
+ new_im.paste(result_images[2], (new_width * 2, 0))
160
+ new_im.paste(result_images[3], (0, new_height))
161
+ new_im.paste(result_images[4], (new_width, new_height))
162
+ new_im.paste(result_images[5], (new_width * 2, new_height))
163
+ new_im.paste(result_images[6], (0, new_height * 2))
164
+ new_im.paste(result_images[7], (new_width, new_height * 2))
165
+ new_im.paste(result_images[8], (new_width * 2, new_height * 2))
166
+ ```
167
+ """
168
+
169
+
170
+ # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
171
+ def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
172
+ """
173
+ Calculate the adaptive tile size based on the image dimensions, ensuring the tile
174
+ respects the aspect ratio and stays within the specified size limits.
175
+ """
176
+ width, height = image_size
177
+ aspect_ratio = width / height
178
+
179
+ if aspect_ratio > 1:
180
+ # Landscape orientation
181
+ tile_width = min(width, max_tile_size)
182
+ tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
183
+ else:
184
+ # Portrait or square orientation
185
+ tile_height = min(height, max_tile_size)
186
+ tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
187
+
188
+ # Ensure the tile size is not smaller than the base_tile_size
189
+ tile_width = max(tile_width, base_tile_size)
190
+ tile_height = max(tile_height, base_tile_size)
191
+
192
+ return tile_width, tile_height
193
+
194
+
195
+ # Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py
196
+ def _tile2pixel_indices(
197
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
198
+ ):
199
+ """Given a tile's row and column numbers, return the range of pixels affected by that tile in the overall image.
200
+
201
+ Returns a tuple with:
202
+ - Starting coordinates of rows in pixel space
203
+ - Ending coordinates of rows in pixel space
204
+ - Starting coordinates of columns in pixel space
205
+ - Ending coordinates of columns in pixel space
206
+ """
207
+ # Calculate initial indices
208
+ px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
209
+ px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
210
+
211
+ # Calculate end indices
212
+ px_row_end = px_row_init + tile_height
213
+ px_col_end = px_col_init + tile_width
214
+
215
+ # Ensure the last tile does not exceed the image dimensions
216
+ px_row_end = min(px_row_end, image_height)
217
+ px_col_end = min(px_col_end, image_width)
218
+
219
+ return px_row_init, px_row_end, px_col_init, px_col_end
220
+
221
+
222
+ # Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py
223
+ def _tile2latent_indices(
224
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
225
+ ):
226
+ """Given a tile's row and column numbers, return the range of latents affected by that tile in the overall image.
227
+
228
+ Returns a tuple with:
229
+ - Starting coordinates of rows in latent space
230
+ - Ending coordinates of rows in latent space
231
+ - Starting coordinates of columns in latent space
232
+ - Ending coordinates of columns in latent space
233
+ """
234
+ # Get pixel indices
235
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
236
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
237
+ )
238
+
239
+ # Convert to latent space
240
+ latent_row_init = px_row_init // 8
241
+ latent_row_end = px_row_end // 8
242
+ latent_col_init = px_col_init // 8
243
+ latent_col_end = px_col_end // 8
244
+ latent_height = image_height // 8
245
+ latent_width = image_width // 8
246
+
247
+ # Ensure the last tile does not exceed the latent dimensions
248
+ latent_row_end = min(latent_row_end, latent_height)
249
+ latent_col_end = min(latent_col_end, latent_width)
250
+
251
+ return latent_row_init, latent_row_end, latent_col_init, latent_col_end
252
+
253
+
254
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
255
+ def retrieve_latents(
256
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
257
+ ):
258
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
259
+ return encoder_output.latent_dist.sample(generator)
260
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
261
+ return encoder_output.latent_dist.mode()
262
+ elif hasattr(encoder_output, "latents"):
263
+ return encoder_output.latents
264
+ else:
265
+ raise AttributeError("Could not access latents of provided encoder_output")
266
+
267
+ class TileWeightingMethod(Enum):
268
+ """Mode in which the tile weights will be generated"""
269
+
270
+ COSINE = "Cosine"
271
+ GAUSSIAN = "Gaussian"
272
+
273
+ class StableDiffusionXLControlNetTileSRPipeline(
274
+ DiffusionPipeline,
275
+ StableDiffusionMixin,
276
+ TextualInversionLoaderMixin,
277
+ StableDiffusionXLLoraLoaderMixin,
278
+ FromSingleFileMixin,
279
+ ):
280
+ r"""
281
+ Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
282
+
283
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
284
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
285
+
286
+ The pipeline also inherits the following loading methods:
287
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
288
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
289
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
290
+
291
+ Args:
292
+ vae ([`AutoencoderKL`]):
293
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
294
+ text_encoder ([`CLIPTextModel`]):
295
+ Frozen text-encoder. Stable Diffusion uses the text portion of
296
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
297
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
298
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
299
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
300
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
301
+ specifically the
302
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
303
+ variant.
304
+ tokenizer (`CLIPTokenizer`):
305
+ Tokenizer of class
306
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
307
+ tokenizer_2 (`CLIPTokenizer`):
308
+ Second Tokenizer of class
309
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
310
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
311
+ controlnet ([`ControlNetUnionModel`]):
312
+ Provides additional conditioning to the unet during the denoising process.
313
+ scheduler ([`SchedulerMixin`]):
314
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
315
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
316
+ requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
317
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
318
+ config of `stabilityai/stable-diffusion-xl-refiner-1.0`.
319
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
320
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
321
+ `stabilityai/stable-diffusion-xl-base-1.0`.
322
+ add_watermarker (`bool`, *optional*):
323
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
324
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
325
+ watermarker will be used.
326
+ """
327
+
328
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
329
+ _optional_components = [
330
+ "tokenizer",
331
+ "tokenizer_2",
332
+ "text_encoder",
333
+ "text_encoder_2",
334
+ ]
335
+
336
+ def __init__(
337
+ self,
338
+ vae: AutoencoderKL,
339
+ text_encoder: CLIPTextModel,
340
+ text_encoder_2: CLIPTextModelWithProjection,
341
+ tokenizer: CLIPTokenizer,
342
+ tokenizer_2: CLIPTokenizer,
343
+ unet: UNet2DConditionModel,
344
+ controlnet: ControlNetUnionModel,
345
+ scheduler: KarrasDiffusionSchedulers,
346
+ requires_aesthetics_score: bool = False,
347
+ force_zeros_for_empty_prompt: bool = True,
348
+ add_watermarker: Optional[bool] = None,
349
+ ):
350
+ super().__init__()
351
+
352
+ if not isinstance(controlnet, ControlNetUnionModel):
353
+ raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
354
+
355
+ self.register_modules(
356
+ vae=vae,
357
+ text_encoder=text_encoder,
358
+ text_encoder_2=text_encoder_2,
359
+ tokenizer=tokenizer,
360
+ tokenizer_2=tokenizer_2,
361
+ unet=unet,
362
+ controlnet=controlnet,
363
+ scheduler=scheduler,
364
+ )
365
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
366
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
367
+ self.control_image_processor = VaeImageProcessor(
368
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
369
+ )
370
+ self.mask_processor = VaeImageProcessor(
371
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
372
+ )
373
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
374
+
375
+ if add_watermarker:
376
+ self.watermark = StableDiffusionXLWatermarker()
377
+ else:
378
+ self.watermark = None
379
+
380
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
381
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
382
+
383
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
384
+ def encode_prompt(
385
+ self,
386
+ prompt: str,
387
+ prompt_2: Optional[str] = None,
388
+ device: Optional[torch.device] = None,
389
+ num_images_per_prompt: int = 1,
390
+ do_classifier_free_guidance: bool = True,
391
+ negative_prompt: Optional[str] = None,
392
+ negative_prompt_2: Optional[str] = None,
393
+ prompt_embeds: Optional[torch.Tensor] = None,
394
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
395
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
396
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
397
+ lora_scale: Optional[float] = None,
398
+ clip_skip: Optional[int] = None,
399
+ ):
400
+ r"""
401
+ Encodes the prompt into text encoder hidden states.
402
+
403
+ Args:
404
+ prompt (`str` or `List[str]`, *optional*):
405
+ prompt to be encoded
406
+ prompt_2 (`str` or `List[str]`, *optional*):
407
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
408
+ used in both text-encoders
409
+ device (`torch.device`):
410
+ torch device
411
+ num_images_per_prompt (`int`):
412
+ number of images that should be generated per prompt
413
+ do_classifier_free_guidance (`bool`):
414
+ whether to use classifier free guidance or not
415
+ negative_prompt (`str` or `List[str]`, *optional*):
416
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
417
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
418
+ less than `1`).
419
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
420
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
421
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
422
+ prompt_embeds (`torch.Tensor`, *optional*):
423
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
424
+ provided, text embeddings will be generated from `prompt` input argument.
425
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
426
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
427
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
428
+ argument.
429
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
430
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
431
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
432
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
433
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
434
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
435
+ input argument.
436
+ lora_scale (`float`, *optional*):
437
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
438
+ clip_skip (`int`, *optional*):
439
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
440
+ the output of the pre-final layer will be used for computing the prompt embeddings.
441
+ """
442
+ device = device or self._execution_device
443
+
444
+ # set lora scale so that monkey patched LoRA
445
+ # function of text encoder can correctly access it
446
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
447
+ self._lora_scale = lora_scale
448
+
449
+ # dynamically adjust the LoRA scale
450
+ if self.text_encoder is not None:
451
+ if not USE_PEFT_BACKEND:
452
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
453
+ else:
454
+ scale_lora_layers(self.text_encoder, lora_scale)
455
+
456
+ if self.text_encoder_2 is not None:
457
+ if not USE_PEFT_BACKEND:
458
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
459
+ else:
460
+ scale_lora_layers(self.text_encoder_2, lora_scale)
461
+
462
+ prompt = [prompt] if isinstance(prompt, str) else prompt
463
+
464
+ if prompt is not None:
465
+ batch_size = len(prompt)
466
+ else:
467
+ batch_size = prompt_embeds.shape[0]
468
+
469
+ # Define tokenizers and text encoders
470
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
471
+ text_encoders = (
472
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
473
+ )
474
+ dtype = text_encoders[0].dtype
475
+ if prompt_embeds is None:
476
+ prompt_2 = prompt_2 or prompt
477
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
478
+
479
+ # textual inversion: process multi-vector tokens if necessary
480
+ prompt_embeds_list = []
481
+ prompts = [prompt, prompt_2]
482
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
483
+ if isinstance(self, TextualInversionLoaderMixin):
484
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
485
+
486
+ text_inputs = tokenizer(
487
+ prompt,
488
+ padding="max_length",
489
+ max_length=tokenizer.model_max_length,
490
+ truncation=True,
491
+ return_tensors="pt",
492
+ )
493
+
494
+ text_input_ids = text_inputs.input_ids
495
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
496
+
497
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
498
+ text_input_ids, untruncated_ids
499
+ ):
500
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
501
+ logger.warning(
502
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
503
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
504
+ )
505
+ text_encoder.to(dtype)
506
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
507
+
508
+ # We are always interested only in the pooled output of the final text encoder
509
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
510
+ pooled_prompt_embeds = prompt_embeds[0]
511
+
512
+ if clip_skip is None:
513
+ prompt_embeds = prompt_embeds.hidden_states[-2]
514
+ else:
515
+ # "2" because SDXL always indexes from the penultimate layer.
516
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
517
+
518
+ prompt_embeds_list.append(prompt_embeds)
519
+
520
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
521
+
522
+ # get unconditional embeddings for classifier free guidance
523
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
524
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
525
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
526
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
527
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
528
+ negative_prompt = negative_prompt or ""
529
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
530
+
531
+ # normalize str to list
532
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
533
+ negative_prompt_2 = (
534
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
535
+ )
536
+
537
+ uncond_tokens: List[str]
538
+ if prompt is not None and type(prompt) is not type(negative_prompt):
539
+ raise TypeError(
540
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
541
+ f" {type(prompt)}."
542
+ )
543
+ elif batch_size != len(negative_prompt):
544
+ raise ValueError(
545
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
546
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
547
+ " the batch size of `prompt`."
548
+ )
549
+ else:
550
+ uncond_tokens = [negative_prompt, negative_prompt_2]
551
+
552
+ negative_prompt_embeds_list = []
553
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
554
+ if isinstance(self, TextualInversionLoaderMixin):
555
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
556
+
557
+ max_length = prompt_embeds.shape[1]
558
+ uncond_input = tokenizer(
559
+ negative_prompt,
560
+ padding="max_length",
561
+ max_length=max_length,
562
+ truncation=True,
563
+ return_tensors="pt",
564
+ )
565
+
566
+ negative_prompt_embeds = text_encoder(
567
+ uncond_input.input_ids.to(device),
568
+ output_hidden_states=True,
569
+ )
570
+
571
+ # We are always interested only in the pooled output of the final text encoder
572
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
573
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
574
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
575
+
576
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
577
+
578
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
579
+
580
+ if self.text_encoder_2 is not None:
581
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
582
+ else:
583
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
584
+
585
+ bs_embed, seq_len, _ = prompt_embeds.shape
586
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
587
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
588
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
589
+
590
+ if do_classifier_free_guidance:
591
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
592
+ seq_len = negative_prompt_embeds.shape[1]
593
+
594
+ if self.text_encoder_2 is not None:
595
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
596
+ else:
597
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
598
+
599
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
600
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
601
+
602
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
603
+ bs_embed * num_images_per_prompt, -1
604
+ )
605
+ if do_classifier_free_guidance:
606
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
607
+ bs_embed * num_images_per_prompt, -1
608
+ )
609
+
610
+ if self.text_encoder is not None:
611
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
612
+ # Retrieve the original scale by scaling back the LoRA layers
613
+ unscale_lora_layers(self.text_encoder, lora_scale)
614
+
615
+ if self.text_encoder_2 is not None:
616
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
617
+ # Retrieve the original scale by scaling back the LoRA layers
618
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
619
+
620
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
621
+
622
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
623
+ def prepare_extra_step_kwargs(self, generator, eta):
624
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
625
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
626
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
627
+ # and should be between [0, 1]
628
+
629
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
630
+ extra_step_kwargs = {}
631
+ if accepts_eta:
632
+ extra_step_kwargs["eta"] = eta
633
+
634
+ # check if the scheduler accepts generator
635
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
636
+ if accepts_generator:
637
+ extra_step_kwargs["generator"] = generator
638
+ return extra_step_kwargs
639
+
640
+ def check_inputs(
641
+ self,
642
+ prompt,
643
+ height,
644
+ width,
645
+ image,
646
+ strength,
647
+ num_inference_steps,
648
+ normal_tile_overlap,
649
+ border_tile_overlap,
650
+ max_tile_size,
651
+ tile_gaussian_sigma,
652
+ tile_weighting_method,
653
+ controlnet_conditioning_scale=1.0,
654
+ control_guidance_start=0.0,
655
+ control_guidance_end=1.0,
656
+ ):
657
+ if height % 8 != 0 or width % 8 != 0:
658
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
659
+
660
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
661
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
662
+
663
+ if strength < 0 or strength > 1:
664
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
665
+ if num_inference_steps is None:
666
+ raise ValueError("`num_inference_steps` cannot be None.")
667
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
668
+ raise ValueError(
669
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
670
+ f" {type(num_inference_steps)}."
671
+ )
672
+ if normal_tile_overlap is None:
673
+ raise ValueError("`normal_tile_overlap` cannot be None.")
674
+ elif not isinstance(normal_tile_overlap, int) or normal_tile_overlap < 64:
675
+ raise ValueError(
676
+ f"`normal_tile_overlap` has to be an integer greater than or equal to 64 but is {normal_tile_overlap} of type"
677
+ f" {type(normal_tile_overlap)}."
678
+ )
679
+ if border_tile_overlap is None:
680
+ raise ValueError("`border_tile_overlap` cannot be None.")
681
+ elif not isinstance(border_tile_overlap, int) or border_tile_overlap < 128:
682
+ raise ValueError(
683
+ f"`border_tile_overlap` has to be an integer greater than or equal to 128 but is {border_tile_overlap} of type"
684
+ f" {type(border_tile_overlap)}."
685
+ )
686
+ if max_tile_size is None:
687
+ raise ValueError("`max_tile_size` cannot be None.")
688
+ elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
689
+ raise ValueError(
690
+ f"`max_tile_size` has to be either 1024 or 1280 but is {max_tile_size} of type"
691
+ f" {type(max_tile_size)}."
692
+ )
693
+ if tile_gaussian_sigma is None:
694
+ raise ValueError("`tile_gaussian_sigma` cannot be None.")
695
+ elif not isinstance(tile_gaussian_sigma, float) or tile_gaussian_sigma <= 0:
696
+ raise ValueError(
697
+ f"`tile_gaussian_sigma` has to be a positive float but is {tile_gaussian_sigma} of type"
698
+ f" {type(tile_gaussian_sigma)}."
699
+ )
700
+ if tile_weighting_method is None:
701
+ raise ValueError("`tile_weighting_method` cannot be None.")
702
+ elif not isinstance(tile_weighting_method, str) or tile_weighting_method not in [t.value for t in TileWeightingMethod]:
703
+ raise ValueError(
704
+ f"`tile_weighting_method` has to be a string in ({[t.value for t in TileWeightingMethod]}) but is {tile_weighting_method} of type"
705
+ f" {type(tile_weighting_method)}."
706
+ )
707
+
708
+ # Check `image`
709
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
710
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
711
+ )
712
+ if (
713
+ isinstance(self.controlnet, ControlNetModel)
714
+ or is_compiled
715
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
716
+ ):
717
+ self.check_image(image, prompt)
718
+ elif (
719
+ isinstance(self.controlnet, ControlNetUnionModel)
720
+ or is_compiled
721
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
722
+ ):
723
+ self.check_image(image, prompt)
724
+ else:
725
+ assert False
726
+
727
+ # Check `controlnet_conditioning_scale`
728
+ if (
729
+ isinstance(self.controlnet, ControlNetUnionModel)
730
+ or is_compiled
731
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
732
+ ) or (
733
+ isinstance(self.controlnet, MultiControlNetModel)
734
+ or is_compiled
735
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
736
+ ):
737
+ if not isinstance(controlnet_conditioning_scale, float):
738
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be of type `float`.")
739
+ elif (
740
+ isinstance(self.controlnet, MultiControlNetModel)
741
+ or is_compiled
742
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
743
+ ):
744
+ if isinstance(controlnet_conditioning_scale, list):
745
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
746
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
747
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
748
+ self.controlnet.nets
749
+ ):
750
+ raise ValueError(
751
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
752
+ " the same length as the number of controlnets"
753
+ )
754
+ else:
755
+ assert False
756
+
757
+ if not isinstance(control_guidance_start, (tuple, list)):
758
+ control_guidance_start = [control_guidance_start]
759
+
760
+ if not isinstance(control_guidance_end, (tuple, list)):
761
+ control_guidance_end = [control_guidance_end]
762
+
763
+ if len(control_guidance_start) != len(control_guidance_end):
764
+ raise ValueError(
765
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
766
+ )
767
+
768
+ for start, end in zip(control_guidance_start, control_guidance_end):
769
+ if start >= end:
770
+ raise ValueError(
771
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
772
+ )
773
+ if start < 0.0:
774
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
775
+ if end > 1.0:
776
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
777
+
778
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
779
+ def check_image(self, image, prompt):
780
+ image_is_pil = isinstance(image, Image.Image)
781
+ image_is_tensor = isinstance(image, torch.Tensor)
782
+ image_is_np = isinstance(image, np.ndarray)
783
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)
784
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
785
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
786
+
787
+ if (
788
+ not image_is_pil
789
+ and not image_is_tensor
790
+ and not image_is_np
791
+ and not image_is_pil_list
792
+ and not image_is_tensor_list
793
+ and not image_is_np_list
794
+ ):
795
+ raise TypeError(
796
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
797
+ )
798
+
799
+ if image_is_pil:
800
+ image_batch_size = 1
801
+ else:
802
+ image_batch_size = len(image)
803
+
804
+ if prompt is not None and isinstance(prompt, str):
805
+ prompt_batch_size = 1
806
+ elif prompt is not None and isinstance(prompt, list):
807
+ prompt_batch_size = len(prompt)
808
+
809
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
810
+ raise ValueError(
811
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
812
+ )
813
+
814
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
815
+ def prepare_control_image(
816
+ self,
817
+ image,
818
+ width,
819
+ height,
820
+ batch_size,
821
+ num_images_per_prompt,
822
+ device,
823
+ dtype,
824
+ do_classifier_free_guidance=False,
825
+ guess_mode=False,
826
+ ):
827
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
828
+ image_batch_size = image.shape[0]
829
+
830
+ if image_batch_size == 1:
831
+ repeat_by = batch_size
832
+ else:
833
+ # image batch size is the same as prompt batch size
834
+ repeat_by = num_images_per_prompt
835
+
836
+ image = image.repeat_interleave(repeat_by, dim=0)
837
+
838
+ image = image.to(device=device, dtype=dtype)
839
+
840
+ if do_classifier_free_guidance and not guess_mode:
841
+ image = torch.cat([image] * 2)
842
+
843
+ return image
844
+
845
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
846
+ def get_timesteps(self, num_inference_steps, strength):
847
+ # get the original timestep using init_timestep
848
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
849
+
850
+ t_start = max(num_inference_steps - init_timestep, 0)
851
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
852
+ if hasattr(self.scheduler, "set_begin_index"):
853
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
854
+
855
+ return timesteps, num_inference_steps - t_start
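The step arithmetic in `get_timesteps` can be checked in isolation. The following is a minimal sketch that mirrors only the integer math above; `trimmed_schedule` is a hypothetical helper name, not part of the pipeline, and it takes the scheduler order as a plain argument instead of reading `self.scheduler`.

```python
def trimmed_schedule(num_inference_steps: int, strength: float, order: int = 1):
    """Return (index of the first timestep kept, number of denoising steps run)."""
    # Number of steps that survive for the given strength, capped at the total.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Steps skipped at the start of the schedule.
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start * order, num_inference_steps - t_start


print(trimmed_schedule(50, 0.9999))  # (1, 49)
print(trimmed_schedule(50, 0.5))     # (25, 25)
print(trimmed_schedule(50, 1.0))     # (0, 50)
```

With the default `strength=0.9999`, one step of a 50-step schedule is skipped, so 49 steps run.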
856
+
857
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
858
+ def prepare_latents(
859
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
860
+ ):
861
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
862
+ raise ValueError(
863
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
864
+ )
865
+
866
+ latents_mean = latents_std = None
867
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
868
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
869
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
870
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
871
+
872
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
873
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
874
+ self.text_encoder_2.to("cpu")
875
+ torch.cuda.empty_cache()
876
+
877
+ image = image.to(device=device, dtype=dtype)
878
+
879
+ batch_size = batch_size * num_images_per_prompt
880
+
881
+ if image.shape[1] == 4:
882
+ init_latents = image
883
+
884
+ else:
885
+ # make sure the VAE is in float32 mode, as it overflows in float16
886
+ if self.vae.config.force_upcast:
887
+ image = image.float()
888
+ self.vae.to(dtype=torch.float32)
889
+
890
+ if isinstance(generator, list) and len(generator) != batch_size:
891
+ raise ValueError(
892
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
893
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
894
+ )
895
+
896
+ elif isinstance(generator, list):
897
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
898
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
899
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
900
+ raise ValueError(
901
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
902
+ )
903
+
904
+ init_latents = [
905
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
906
+ for i in range(batch_size)
907
+ ]
908
+ init_latents = torch.cat(init_latents, dim=0)
909
+ else:
910
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
911
+
912
+ if self.vae.config.force_upcast:
913
+ self.vae.to(dtype)
914
+
915
+ init_latents = init_latents.to(dtype)
916
+ if latents_mean is not None and latents_std is not None:
917
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
918
+ latents_std = latents_std.to(device=device, dtype=dtype)
919
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
920
+ else:
921
+ init_latents = self.vae.config.scaling_factor * init_latents
922
+
923
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
924
+ # expand init_latents for batch_size
925
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
926
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
927
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
928
+ raise ValueError(
929
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
930
+ )
931
+ else:
932
+ init_latents = torch.cat([init_latents], dim=0)
933
+
934
+ if add_noise:
935
+ shape = init_latents.shape
936
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
937
+ # get latents
938
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
939
+
940
+ latents = init_latents
941
+
942
+ return latents
943
+
944
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
945
+ def _get_add_time_ids(
946
+ self,
947
+ original_size,
948
+ crops_coords_top_left,
949
+ target_size,
950
+ aesthetic_score,
951
+ negative_aesthetic_score,
952
+ negative_original_size,
953
+ negative_crops_coords_top_left,
954
+ negative_target_size,
955
+ dtype,
956
+ text_encoder_projection_dim=None,
957
+ ):
958
+ if self.config.requires_aesthetics_score:
959
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
960
+ add_neg_time_ids = list(
961
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
962
+ )
963
+ else:
964
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
965
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
966
+
967
+ passed_add_embed_dim = (
968
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
969
+ )
970
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
971
+
972
+ if (
973
+ expected_add_embed_dim > passed_add_embed_dim
974
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
975
+ ):
976
+ raise ValueError(
977
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
978
+ )
979
+ elif (
980
+ expected_add_embed_dim < passed_add_embed_dim
981
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
982
+ ):
983
+ raise ValueError(
984
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
985
+ )
986
+ elif expected_add_embed_dim != passed_add_embed_dim:
987
+ raise ValueError(
988
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
989
+ )
990
+
991
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
992
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
993
+
994
+ return add_time_ids, add_neg_time_ids
995
+
996
+ def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
997
+ """
998
+ Generates cosine weights as a PyTorch tensor for blending tiles.
999
+
1000
+ Args:
1001
+ tile_width (int): Width of the tile in pixels.
1002
+ tile_height (int): Height of the tile in pixels.
1003
+ nbatches (int): Number of batches.
1004
+ device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').
1005
+ dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).
1006
+
1007
+ Returns:
1008
+ torch.Tensor: A tensor containing cosine weights for blending tiles, expanded to match batch and channel dimensions.
1009
+ """
1010
+ # Convert tile dimensions to latent space
1011
+ latent_width = tile_width // 8
1012
+ latent_height = tile_height // 8
1013
+
1014
+ # Generate x and y coordinates in latent space
1015
+ x = np.arange(0, latent_width)
1016
+ y = np.arange(0, latent_height)
1017
+
1018
+ # Calculate midpoints
1019
+ midpoint_x = (latent_width - 1) / 2
1020
+ midpoint_y = (latent_height - 1) / 2
1021
+
1022
+ # Compute cosine probabilities for x and y
1023
+ x_probs = np.cos(np.pi * (x - midpoint_x) / latent_width)
1024
+ y_probs = np.cos(np.pi * (y - midpoint_y) / latent_height)
1025
+
1026
+ # Create a 2D weight matrix using the outer product
1027
+ weights_np = np.outer(y_probs, x_probs)
1028
+
1029
+ # Convert to a PyTorch tensor with the correct device and dtype
1030
+ weights_torch = torch.tensor(weights_np, device=device, dtype=dtype)
1031
+
1032
+ # Expand for batch and channel dimensions
1033
+ tile_weights_expanded = torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
1034
+
1035
+ return tile_weights_expanded
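The cosine window can be inspected without a pipeline instance. This is a minimal NumPy-only sketch of the same formula; `cosine_window` and its `downscale` parameter are hypothetical names introduced here, and the factor of 8 is taken from the `// 8` in the method above. It stops before the tensor conversion and batch/channel tiling.

```python
import numpy as np


def cosine_window(tile_width: int, tile_height: int, downscale: int = 8) -> np.ndarray:
    """2D cosine blending window in latent space, shape (latent_h, latent_w)."""
    latent_width = tile_width // downscale
    latent_height = tile_height // downscale
    x = np.arange(latent_width)
    y = np.arange(latent_height)
    # Centre the cosine on the middle of the tile.
    x_probs = np.cos(np.pi * (x - (latent_width - 1) / 2) / latent_width)
    y_probs = np.cos(np.pi * (y - (latent_height - 1) / 2) / latent_height)
    # Outer product gives a separable 2D window (rows follow y, columns follow x).
    return np.outer(y_probs, x_probs)


w = cosine_window(64, 32)
print(w.shape)      # (4, 8)
print(w.min() > 0)  # True
```

Because the argument of the cosine stays strictly inside (-π/2, π/2), the window is positive everywhere, including at the tile edges.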
1036
+
1037
+ def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.05):
1038
+ """
1039
+ Generates Gaussian weights as a PyTorch tensor for blending tiles in latent space.
1040
+
1041
+ Args:
1042
+ tile_width (int): Width of the tile in pixels.
1043
+ tile_height (int): Height of the tile in pixels.
1044
+ nbatches (int): Number of batches.
1045
+ device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').
1046
+ dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).
1047
+ sigma (float, optional): Standard deviation of the Gaussian distribution. Controls the smoothness of the weights. Defaults to 0.05.
1048
+
1049
+ Returns:
1050
+ torch.Tensor: A tensor containing Gaussian weights for blending tiles, expanded to match batch and channel dimensions.
1051
+ """
1052
+ # Convert tile dimensions to latent space
1053
+ latent_width = tile_width // 8
1054
+ latent_height = tile_height // 8
1055
+
1056
+ # Generate Gaussian weights in latent space
1057
+ x = np.linspace(-1, 1, latent_width)
1058
+ y = np.linspace(-1, 1, latent_height)
1059
+ xx, yy = np.meshgrid(x, y)
1060
+ gaussian_weight = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
1061
+
1062
+ # Convert to a PyTorch tensor with the correct device and dtype
1063
+ weights_torch = torch.tensor(gaussian_weight, device=device, dtype=dtype)
1064
+
1065
+ # Expand for batch and channel dimensions
1066
+ weights_expanded = weights_torch.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
1067
+ weights_expanded = weights_expanded.expand(nbatches, -1, -1, -1) # Expand to the number of batches
1068
+
1069
+ return weights_expanded
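For comparison, the Gaussian window can be sketched the same way. `gaussian_window` and `downscale` are hypothetical names used only for this illustration; the formula and the default `sigma=0.05` come from the method above.

```python
import numpy as np


def gaussian_window(tile_width: int, tile_height: int, sigma: float = 0.05,
                    downscale: int = 8) -> np.ndarray:
    """2D Gaussian blending window in latent space, shape (latent_h, latent_w)."""
    # Coordinates are normalised to [-1, 1] regardless of tile size.
    x = np.linspace(-1, 1, tile_width // downscale)
    y = np.linspace(-1, 1, tile_height // downscale)
    xx, yy = np.meshgrid(x, y)
    return np.exp(-(xx**2 + yy**2) / (2 * sigma**2))


w = gaussian_window(1024, 1024)
print(w.shape)                         # (128, 128)
print(w.astype(np.float32)[0, 0])      # 0.0
```

With `sigma=0.05` on a [-1, 1] grid, the corner weight is exp(-400), which underflows to zero in float32, so this window is much more sharply peaked than the cosine one. Note also that the method above returns a tensor of shape `(nbatches, 1, H, W)` and relies on broadcasting across channels, whereas `_generate_cosine_weights` tiles the window across `in_channels` explicitly.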
1070
+
1071
+ def _get_num_tiles(self, height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap):
1072
+ """
1073
+ Calculates the number of tiles needed to cover an image, choosing the appropriate formula based on the
1074
+ ratio between the image size and the tile size.
1075
+
1076
+ This function automatically selects between two formulas:
1077
+ 1. A universal formula for typical cases (image-to-tile ratio <= 6:1).
1078
+ 2. A specialized formula with border tile overlap for larger or atypical cases (image-to-tile ratio > 6:1).
1079
+
1080
+ Args:
1081
+ height (int): Height of the image in pixels.
1082
+ width (int): Width of the image in pixels.
1083
+ tile_height (int): Height of each tile in pixels.
1084
+ tile_width (int): Width of each tile in pixels.
1085
+ normal_tile_overlap (int): Overlap between tiles in pixels for normal (non-border) tiles.
1086
+ border_tile_overlap (int): Overlap between tiles in pixels for border tiles.
1087
+
1088
+ Returns:
1089
+ tuple: A tuple containing:
1090
+ - grid_rows (int): Number of rows in the tile grid.
1091
+ - grid_cols (int): Number of columns in the tile grid.
1092
+
1093
+ Notes:
1094
+ - The function uses the universal formula (without border_tile_overlap) for typical cases where the
1095
+ image-to-tile ratio is 6:1 or smaller.
1096
+ - For larger or atypical cases (image-to-tile ratio > 6:1), it uses a specialized formula that includes
1097
+ border_tile_overlap to ensure complete coverage of the image, especially at the edges.
1098
+ """
1099
+ # Calculate the ratio between the image size and the tile size
1100
+ height_ratio = height / tile_height
1101
+ width_ratio = width / tile_width
1102
+
1103
+ # If the ratio is greater than 6:1, use the formula with border_tile_overlap
1104
+ if height_ratio > 6 or width_ratio > 6:
1105
+ grid_rows = int(np.ceil((height - border_tile_overlap) / (tile_height - normal_tile_overlap))) + 1
1106
+ grid_cols = int(np.ceil((width - border_tile_overlap) / (tile_width - normal_tile_overlap))) + 1
1107
+ else:
1108
+ # Otherwise, use the universal formula
1109
+ grid_rows = int(np.ceil((height - normal_tile_overlap) / (tile_height - normal_tile_overlap)))
1110
+ grid_cols = int(np.ceil((width - normal_tile_overlap) / (tile_width - normal_tile_overlap)))
1111
+
1112
+ return grid_rows, grid_cols
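The two grid formulas are easy to compare on concrete numbers. The sketch below reproduces the arithmetic of `_get_num_tiles` as a free function; `num_tiles` is a hypothetical name, and `math.ceil` is used in place of `np.ceil` to keep the example dependency-free.

```python
import math


def num_tiles(height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap):
    """Return (grid_rows, grid_cols) using the same branch logic as the method above."""
    if height / tile_height > 6 or width / tile_width > 6:
        # Large ratio: subtract the border overlap and add one extra row/column.
        rows = math.ceil((height - border_tile_overlap) / (tile_height - normal_tile_overlap)) + 1
        cols = math.ceil((width - border_tile_overlap) / (tile_width - normal_tile_overlap)) + 1
    else:
        # Typical case: universal formula.
        rows = math.ceil((height - normal_tile_overlap) / (tile_height - normal_tile_overlap))
        cols = math.ceil((width - normal_tile_overlap) / (tile_width - normal_tile_overlap))
    return rows, cols


print(num_tiles(2048, 2048, 1024, 1024, 64, 128))  # (3, 3)
print(num_tiles(8192, 2048, 1024, 1024, 64, 128))  # (10, 3)
```

In the second call only the height ratio exceeds 6:1, but the border formula is applied to both axes, because the branch is selected once for the whole image.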
1113
+
1114
+ def prepare_tiles(
1115
+ self,
1116
+ grid_rows,
1117
+ grid_cols,
1118
+ tile_weighting_method,
1119
+ tile_width,
1120
+ tile_height,
1121
+ normal_tile_overlap,
1122
+ border_tile_overlap,
1123
+ width,
1124
+ height,
1125
+ tile_sigma,
1126
+ batch_size,
1127
+ device,
1128
+ dtype,
1129
+ ):
1130
+ """
1131
+ Processes image tiles by dynamically adjusting overlap and calculating Gaussian or cosine weights.
1132
+
1133
+ Args:
1134
+ grid_rows (int): Number of rows in the tile grid.
1135
+ grid_cols (int): Number of columns in the tile grid.
1136
+ tile_weighting_method (str): Method for weighting tiles. Options: "Gaussian" or "Cosine".
1137
+ tile_width (int): Width of each tile in pixels.
1138
+ tile_height (int): Height of each tile in pixels.
1139
+ normal_tile_overlap (int): Overlap between tiles in pixels for normal tiles.
1140
+ border_tile_overlap (int): Overlap between tiles in pixels for border tiles.
1141
+ width (int): Width of the image in pixels.
1142
+ height (int): Height of the image in pixels.
1143
+ tile_sigma (float): Sigma parameter for Gaussian weighting.
1144
+ batch_size (int): Batch size for weight tiles.
1145
+ device (torch.device): Device where tensors will be allocated (e.g., 'cuda' or 'cpu').
1146
+ dtype (torch.dtype): Data type of the tensors (e.g., torch.float32).
1147
+
1148
+ Returns:
1149
+ tuple: A tuple containing:
1150
+ - tile_weights (np.ndarray): Array of weights for each tile.
1151
+ - tile_row_overlaps (np.ndarray): Array of row overlaps for each tile.
1152
+ - tile_col_overlaps (np.ndarray): Array of column overlaps for each tile.
1153
+ """
1154
+
1155
+ # Create arrays to store dynamic overlaps and weights
1156
+ tile_row_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)
1157
+ tile_col_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)
1158
+ tile_weights = np.empty((grid_rows, grid_cols), dtype=object) # Stores Gaussian or cosine weights
1159
+
1160
+ # Iterate over tiles to adjust overlap and calculate weights
1161
+ for row in range(grid_rows):
1162
+ for col in range(grid_cols):
1163
+ # Calculate the size of the current tile
1164
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
1165
+ row, col, tile_width, tile_height, normal_tile_overlap, normal_tile_overlap, width, height
1166
+ )
1167
+ current_tile_width = px_col_end - px_col_init
1168
+ current_tile_height = px_row_end - px_row_init
1169
+ sigma = tile_sigma
1170
+
1171
+ # Adjust overlap for smaller tiles
1172
+ if current_tile_width < tile_width:
1173
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
1174
+ row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height
1175
+ )
1176
+ current_tile_width = px_col_end - px_col_init
1177
+ tile_col_overlaps[row, col] = border_tile_overlap
1178
+ sigma = tile_sigma * 1.2
1179
+ if current_tile_height < tile_height:
1180
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
1181
+ row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height
1182
+ )
1183
+ current_tile_height = px_row_end - px_row_init
1184
+ tile_row_overlaps[row, col] = border_tile_overlap
1185
+ sigma = tile_sigma * 1.2
1186
+
1187
+ # Calculate weights for the current tile
1188
+ if tile_weighting_method == TileWeightingMethod.COSINE.value:
1189
+ tile_weights[row, col] = self._generate_cosine_weights(
1190
+ tile_width=current_tile_width,
1191
+ tile_height=current_tile_height,
1192
+ nbatches=batch_size,
1193
+ device=device,
1194
+ dtype=torch.float32,
1195
+ )
1196
+ else:
1197
+ tile_weights[row, col] = self._generate_gaussian_weights(
1198
+ tile_width=current_tile_width,
1199
+ tile_height=current_tile_height,
1200
+ nbatches=batch_size,
1201
+ device=device,
1202
+ dtype=dtype,
1203
+ sigma=sigma,
1204
+ )
1205
+
1206
+ return tile_weights, tile_row_overlaps, tile_col_overlaps
1207
+
1208
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1209
+ def upcast_vae(self):
1210
+ dtype = self.vae.dtype
1211
+ self.vae.to(dtype=torch.float32)
1212
+ use_torch_2_0_or_xformers = isinstance(
1213
+ self.vae.decoder.mid_block.attentions[0].processor,
1214
+ (
1215
+ AttnProcessor2_0,
1216
+ XFormersAttnProcessor,
1217
+ ),
1218
+ )
1219
+ # if xformers or torch_2_0 is used attention block does not need
1220
+ # to be in float32 which can save lots of memory
1221
+ if use_torch_2_0_or_xformers:
1222
+ self.vae.post_quant_conv.to(dtype)
1223
+ self.vae.decoder.conv_in.to(dtype)
1224
+ self.vae.decoder.mid_block.to(dtype)
1225
+
1226
+ @property
1227
+ def guidance_scale(self):
1228
+ return self._guidance_scale
1229
+
1230
+ @property
1231
+ def clip_skip(self):
1232
+ return self._clip_skip
1233
+
1234
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1235
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1236
+ # corresponds to doing no classifier free guidance.
1237
+ @property
1238
+ def do_classifier_free_guidance(self):
1239
+ return self._guidance_scale > 1
1240
+
1241
+ @property
1242
+ def cross_attention_kwargs(self):
1243
+ return self._cross_attention_kwargs
1244
+
1245
+ @property
1246
+ def num_timesteps(self):
1247
+ return self._num_timesteps
1248
+
1249
+ @property
1250
+ def interrupt(self):
1251
+ return self._interrupt
1252
+
1253
+ @torch.no_grad()
1254
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1255
+ def __call__(
1256
+ self,
1257
+ prompt: Union[str, List[str]] = None,
1258
+ image: PipelineImageInput = None,
1259
+ control_image: PipelineImageInput = None,
1260
+ height: Optional[int] = None,
1261
+ width: Optional[int] = None,
1262
+ strength: float = 0.9999,
1263
+ num_inference_steps: int = 50,
1264
+ guidance_scale: float = 5.0,
1265
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1266
+ num_images_per_prompt: Optional[int] = 1,
1267
+ eta: float = 0.0,
1268
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1269
+ latents: Optional[torch.Tensor] = None,
1270
+ output_type: Optional[str] = "pil",
1271
+ return_dict: bool = True,
1272
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1273
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1274
+ guess_mode: bool = False,
1275
+ control_guidance_start: Union[float, List[float]] = 0.0,
1276
+ control_guidance_end: Union[float, List[float]] = 1.0,
1277
+ control_mode: Optional[Union[int, List[int]]] = None,
1278
+ original_size: Tuple[int, int] = None,
1279
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1280
+ target_size: Tuple[int, int] = None,
1281
+ negative_original_size: Optional[Tuple[int, int]] = None,
1282
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1283
+ negative_target_size: Optional[Tuple[int, int]] = None,
1284
+ aesthetic_score: float = 6.0,
1285
+ negative_aesthetic_score: float = 2.5,
1286
+ clip_skip: Optional[int] = None,
1287
+ normal_tile_overlap: int = 64,
1288
+ border_tile_overlap: int = 128,
1289
+ max_tile_size: int = 1024,
1290
+ tile_gaussian_sigma: float = 0.05,
1291
+ tile_weighting_method: str = "Cosine",
1292
+ **kwargs,
1293
+ ):
1294
+ r"""
1295
+ Function invoked when calling the pipeline for generation.
1296
+
1297
+ Args:
1298
+ prompt (`str` or `List[str]`, *optional*):
1299
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1300
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, *optional*):
1301
+ The initial image to be used as the starting point for the image generation process. Can also accept
1302
+ image latents as `image`, if passing latents directly, they will not be encoded again.
1303
+ control_image (`PipelineImageInput`, *optional*):
1304
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for Unet.
1305
+ If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
1306
+ be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
1307
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
1308
+ init, images must be passed as a list such that each element of the list can be correctly batched for
1309
+ input to a single ControlNet.
1310
+ height (`int`, *optional*):
1311
+ The height in pixels of the generated image. If not provided, defaults to the height of `control_image`.
1312
+ width (`int`, *optional*):
1313
+ The width in pixels of the generated image. If not provided, defaults to the width of `control_image`.
1314
+ strength (`float`, *optional*, defaults to 0.9999):
1315
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
1316
+ starting point, and more noise is added the higher the `strength`. The number of denoising steps depends
1317
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum, and the denoising
1318
+ process runs for the full number of iterations specified in `num_inference_steps`.
1319
+ num_inference_steps (`int`, *optional*, defaults to 50):
1320
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1321
+ expense of slower inference.
1322
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1323
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1324
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
1325
+ Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating
1326
+ images closely linked to the text `prompt`, usually at the expense of lower image quality.
1327
+ negative_prompt (`str` or `List[str]`, *optional*):
1328
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1329
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1330
+ less than or equal to `1`).
1331
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1332
+ The number of images to generate per prompt.
1333
+ eta (`float`, *optional*, defaults to 0.0):
1334
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1335
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1336
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1337
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1338
+ to make generation deterministic.
1339
+ latents (`torch.Tensor`, *optional*):
1340
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1341
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1342
+ tensor will be generated by sampling using the supplied random `generator`.
1343
+ output_type (`str`, *optional*, defaults to `"pil"`):
1344
+ The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):
1345
+ `PIL.Image.Image` or `np.array`.
1346
+ return_dict (`bool`, *optional*, defaults to `True`):
1347
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1348
+ plain tuple.
1349
+ cross_attention_kwargs (`dict`, *optional*):
1350
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1351
+ `self.processor` in
1352
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1353
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1354
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1355
+ to the residual in the original UNet. If multiple ControlNets are specified in init, you can set the
1356
+ corresponding scale as a list.
1357
+ guess_mode (`bool`, *optional*, defaults to `False`):
1358
+ In this mode, the ControlNet encoder will try to recognize the content of the input image even if
1359
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
1360
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1361
+ The percentage of total steps at which the ControlNet starts applying.
1362
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1363
+ The percentage of total steps at which the ControlNet stops applying.
1364
+ control_mode (`int` or `List[int]`, *optional*):
1365
+ The mode of ControlNet guidance. Can be used to specify different behaviors for multiple ControlNets.
1366
+ original_size (`Tuple[int, int]`, *optional*):
1367
+ If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.
1368
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning.
1369
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):
1370
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1371
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1372
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning.
1373
+ target_size (`Tuple[int, int]`, *optional*):
1374
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1375
+ not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning.
1376
+ negative_original_size (`Tuple[int, int]`, *optional*):
1377
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1378
+ micro-conditioning.
1379
+ negative_crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):
1380
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
1381
+ micro-conditioning.
1382
+ negative_target_size (`Tuple[int, int]`, *optional*):
1383
+ To negatively condition the generation process based on a target image resolution. It should be the same
1384
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning.
1385
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1386
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1387
+ Part of SDXL's micro-conditioning.
1388
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1389
+ Used to simulate an aesthetic score of the generated image by influencing the negative text condition.
1390
+ Part of SDXL's micro-conditioning.
1391
+ clip_skip (`int`, *optional*):
1392
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1393
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1394
+ normal_tile_overlap (`int`, *optional*, defaults to 64):
1395
+ Number of overlapping pixels between adjacent normal (non-border) tiles.
1396
+ border_tile_overlap (`int`, *optional*, defaults to 128):
1397
+ Number of overlapping pixels between tiles at the borders.
1398
+ max_tile_size (`int`, *optional*, defaults to 1024):
1399
+ Maximum size of a tile in pixels.
1400
+ tile_gaussian_sigma (`float`, *optional*, defaults to 0.3):
1401
+ Sigma parameter for Gaussian weighting of tiles.
1402
+ tile_weighting_method (`str`, *optional*, defaults to "Cosine"):
1403
+ Method for weighting tiles. Options: "Cosine" or "Gaussian".
1404
+
1405
+ Examples:
1406
+
1407
+ Returns:
1408
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1408
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
1410
+ containing the output images.
1411
+ """
1412
+
1413
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1414
+
1415
+ # align format for control guidance
1416
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1417
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1418
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1419
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1420
+
1421
+ if not isinstance(control_image, list):
1422
+ control_image = [control_image]
1423
+ else:
1424
+ control_image = control_image.copy()
1425
+
1426
+ if control_mode is None or isinstance(control_mode, list) and len(control_mode) == 0:
1427
+ raise ValueError("The value for `control_mode` is expected!")
1428
+
1429
+ if not isinstance(control_mode, list):
1430
+ control_mode = [control_mode]
1431
+
1432
+ if len(control_image) != len(control_mode):
1433
+ raise ValueError("Expected len(control_image) == len(control_mode)")
1434
+
1435
+ num_control_type = controlnet.config.num_control_type
1436
+
1437
+ # 0. Set internal use parameters
1438
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1439
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1440
+ original_size = original_size or (height, width)
1441
+ target_size = target_size or (height, width)
1442
+ negative_original_size = negative_original_size or original_size
1443
+ negative_target_size = negative_target_size or target_size
1444
+ control_type = [0 for _ in range(num_control_type)]
1445
+ control_type = torch.Tensor(control_type)
1446
+ self._guidance_scale = guidance_scale
1447
+ self._clip_skip = clip_skip
1448
+ self._cross_attention_kwargs = cross_attention_kwargs
1449
+ self._interrupt = False
1450
+ batch_size = 1
1451
+ device = self._execution_device
1452
+ global_pool_conditions = controlnet.config.global_pool_conditions
1453
+ guess_mode = guess_mode or global_pool_conditions
1454
+
1455
+ # 1. Check inputs
1456
+ for _image, control_idx in zip(control_image, control_mode):
1457
+ control_type[control_idx] = 1
1458
+ self.check_inputs(
1459
+ prompt,
1460
+ height,
1461
+ width,
1462
+ _image,
1463
+ strength,
1464
+ num_inference_steps,
1465
+ normal_tile_overlap,
1466
+ border_tile_overlap,
1467
+ max_tile_size,
1468
+ tile_gaussian_sigma,
1469
+ tile_weighting_method,
1470
+ controlnet_conditioning_scale,
1471
+ control_guidance_start,
1472
+ control_guidance_end,
1473
+ )
1474
+
1475
+ # 2. Get tile width and tile height
1476
+ tile_width, tile_height = _adaptive_tile_size((width, height), max_tile_size=max_tile_size)
1477
+
1478
+ # 2.1 Calculate the number of tiles needed
1479
+ grid_rows, grid_cols = self._get_num_tiles(height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap)
1480
+
1481
+ # 2.2 Expand prompt to number of tiles
1482
+ if not isinstance(prompt, list):
1483
+ prompt = [[prompt] * grid_cols] * grid_rows
1484
+
1485
+ # 2.3 Update height and width tile size by tile size and tile overlap size
1486
+ width = (grid_cols - 1) * (tile_width - normal_tile_overlap) + min(
1487
+ tile_width, width - (grid_cols - 1) * (tile_width - normal_tile_overlap)
1488
+ )
1489
+ height = (grid_rows - 1) * (tile_height - normal_tile_overlap) + min(
1490
+ tile_height, height - (grid_rows - 1) * (tile_height - normal_tile_overlap)
1491
+ )
1492
+
1493
+ # 3. Encode input prompt
1494
+ text_encoder_lora_scale = (
1495
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1496
+ )
1497
+ text_embeddings = [
1498
+ [
1499
+ self.encode_prompt(
1500
+ prompt=col,
1501
+ device=device,
1502
+ num_images_per_prompt=num_images_per_prompt,
1503
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1504
+ negative_prompt=negative_prompt,
1505
+ prompt_embeds=None,
1506
+ negative_prompt_embeds=None,
1507
+ pooled_prompt_embeds=None,
1508
+ negative_pooled_prompt_embeds=None,
1509
+ lora_scale=text_encoder_lora_scale,
1510
+ clip_skip=self.clip_skip,
1511
+ )
1512
+ for col in row
1513
+ ]
1514
+ for row in prompt
1515
+ ]
1516
+
1517
+ # 4. Prepare latent image
1518
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
1519
+
1520
+ # 4.1 Prepare controlnet_conditioning_image
1521
+ control_image = self.prepare_control_image(
1522
+ image=image,
1523
+ width=width,
1524
+ height=height,
1525
+ batch_size=batch_size * num_images_per_prompt,
1526
+ num_images_per_prompt=num_images_per_prompt,
1527
+ device=device,
1528
+ dtype=controlnet.dtype,
1529
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1530
+ guess_mode=guess_mode,
1531
+ )
1532
+ control_type = (
1533
+ control_type.reshape(1, -1)
1534
+ .to(device, dtype=controlnet.dtype)
1535
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
1536
+ )
1537
+
1538
+ # 5. Prepare timesteps
1539
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
1540
+ extra_set_kwargs = {}
1541
+ if accepts_offset:
1542
+ extra_set_kwargs["offset"] = 1
1543
+ self.scheduler.set_timesteps(num_inference_steps, device=device, **extra_set_kwargs)
1544
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
1545
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1546
+ self._num_timesteps = len(timesteps)
1547
+
1548
+ # 6. Prepare latent variables
1549
+ dtype = text_embeddings[0][0][0].dtype
1550
+ if latents is None:
1551
+ latents = self.prepare_latents(
1552
+ image_tensor,
1553
+ latent_timestep,
1554
+ batch_size,
1555
+ num_images_per_prompt,
1556
+ dtype,
1557
+ device,
1558
+ generator,
1559
+ True,
1560
+ )
1561
+
1562
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
1563
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
1564
+ latents = latents * self.scheduler.sigmas[0]
1565
+
1566
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1567
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1568
+
1569
+ # 8. Create tensor stating which controlnets to keep
1570
+ controlnet_keep = []
1571
+ for i in range(len(timesteps)):
1572
+ controlnet_keep.append(
1573
+ 1.0
1574
+ - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
1575
+ )
1576
+
1577
+ # 8.1 Prepare added time ids & embeddings
1578
+ # text_embeddings order: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
1579
+ embeddings_and_added_time = []
1580
+ crops_coords_top_left = negative_crops_coords_top_left = (tile_width, tile_height)
1581
+ for row in range(grid_rows):
1582
+ addition_embed_type_row = []
1583
+ for col in range(grid_cols):
1584
+ # extract generated values
1585
+ prompt_embeds = text_embeddings[row][col][0]
1586
+ negative_prompt_embeds = text_embeddings[row][col][1]
1587
+ pooled_prompt_embeds = text_embeddings[row][col][2]
1588
+ negative_pooled_prompt_embeds = text_embeddings[row][col][3]
1589
+
1590
+ if negative_original_size is None:
1591
+ negative_original_size = original_size
1592
+ if negative_target_size is None:
1593
+ negative_target_size = target_size
1594
+ add_text_embeds = pooled_prompt_embeds
1595
+
1596
+ if self.text_encoder_2 is None:
1597
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1598
+ else:
1599
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1600
+
1601
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1602
+ original_size,
1603
+ crops_coords_top_left,
1604
+ target_size,
1605
+ aesthetic_score,
1606
+ negative_aesthetic_score,
1607
+ negative_original_size,
1608
+ negative_crops_coords_top_left,
1609
+ negative_target_size,
1610
+ dtype=prompt_embeds.dtype,
1611
+ text_encoder_projection_dim=text_encoder_projection_dim,
1612
+ )
1613
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1614
+
1615
+ if self.do_classifier_free_guidance:
1616
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1617
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1618
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1619
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1620
+
1621
+ prompt_embeds = prompt_embeds.to(device)
1622
+ add_text_embeds = add_text_embeds.to(device)
1623
+ add_time_ids = add_time_ids.to(device)
1624
+ addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
1625
+
1626
+ embeddings_and_added_time.append(addition_embed_type_row)
1627
+
1628
+ # 9. Prepare tile weights and latent overlap sizes for the denoising process
1629
+ tile_weights, tile_row_overlaps, tile_col_overlaps = self.prepare_tiles(
1630
+ grid_rows,
1631
+ grid_cols,
1632
+ tile_weighting_method,
1633
+ tile_width,
1634
+ tile_height,
1635
+ normal_tile_overlap,
1636
+ border_tile_overlap,
1637
+ width,
1638
+ height,
1639
+ tile_gaussian_sigma,
1640
+ batch_size,
1641
+ device,
1642
+ dtype,
1643
+ )
1644
+
1645
+ # 10. Denoising loop
1646
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1647
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1648
+ for i, t in enumerate(timesteps):
1649
+ # Diffuse each tile
1650
+ noise_preds = []
1651
+ for row in range(grid_rows):
1652
+ noise_preds_row = []
1653
+ for col in range(grid_cols):
1654
+ if self.interrupt:
1655
+ continue
1656
+ tile_row_overlap = tile_row_overlaps[row, col]
1657
+ tile_col_overlap = tile_col_overlaps[row, col]
1658
+
1659
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
1660
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
1661
+ )
1662
+
1663
+ tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
1664
+
1665
+ # expand the latents if we are doing classifier free guidance
1666
+ latent_model_input = (
1667
+ torch.cat([tile_latents] * 2)
1668
+ if self.do_classifier_free_guidance
1669
+ else tile_latents # 1, 4, ...
1670
+ )
1671
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1672
+
1673
+ # predict the noise residual
1674
+ added_cond_kwargs = {
1675
+ "text_embeds": embeddings_and_added_time[row][col][1],
1676
+ "time_ids": embeddings_and_added_time[row][col][2],
1677
+ }
1678
+
1679
+ # controlnet(s) inference
1680
+ if guess_mode and self.do_classifier_free_guidance:
1681
+ # Infer ControlNet only for the conditional batch.
1682
+ control_model_input = tile_latents
1683
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1684
+ controlnet_prompt_embeds = embeddings_and_added_time[row][col][0].chunk(2)[1]
1685
+ controlnet_added_cond_kwargs = {
1686
+ "text_embeds": embeddings_and_added_time[row][col][1].chunk(2)[1],
1687
+ "time_ids": embeddings_and_added_time[row][col][2].chunk(2)[1],
1688
+ }
1689
+ else:
1690
+ control_model_input = latent_model_input
1691
+ controlnet_prompt_embeds = embeddings_and_added_time[row][col][0]
1692
+ controlnet_added_cond_kwargs = added_cond_kwargs
1693
+
1694
+ if isinstance(controlnet_keep[i], list):
1695
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1696
+ else:
1697
+ controlnet_cond_scale = controlnet_conditioning_scale
1698
+ if isinstance(controlnet_cond_scale, list):
1699
+ controlnet_cond_scale = controlnet_cond_scale[0]
1700
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1701
+
1702
+ px_row_init_pixel, px_row_end_pixel, px_col_init_pixel, px_col_end_pixel = _tile2pixel_indices(
1703
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
1704
+ )
1705
+
1706
+ tile_control_image = control_image[
1707
+ :, :, px_row_init_pixel:px_row_end_pixel, px_col_init_pixel:px_col_end_pixel
1708
+ ]
1709
+
1710
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1711
+ control_model_input,
1712
+ t,
1713
+ encoder_hidden_states=controlnet_prompt_embeds,
1714
+ controlnet_cond=[tile_control_image],
1715
+ control_type=control_type,
1716
+ control_type_idx=control_mode,
1717
+ conditioning_scale=cond_scale,
1718
+ guess_mode=guess_mode,
1719
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1720
+ return_dict=False,
1721
+ )
1722
+
1723
+ if guess_mode and self.do_classifier_free_guidance:
1724
+ # Inferred ControlNet only for the conditional batch.
1725
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1726
+ # add 0 to the unconditional batch to keep it unchanged.
1727
+ down_block_res_samples = [
1728
+ torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
1729
+ ]
1730
+ mid_block_res_sample = torch.cat(
1731
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
1732
+ )
1733
+
1734
+ # predict the noise residual
1735
+ with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
1736
+ noise_pred = self.unet(
1737
+ latent_model_input,
1738
+ t,
1739
+ encoder_hidden_states=embeddings_and_added_time[row][col][0],
1740
+ cross_attention_kwargs=self.cross_attention_kwargs,
1741
+ down_block_additional_residuals=down_block_res_samples,
1742
+ mid_block_additional_residual=mid_block_res_sample,
1743
+ added_cond_kwargs=added_cond_kwargs,
1744
+ return_dict=False,
1745
+ )[0]
1746
+
1747
+ # perform guidance
1748
+ if self.do_classifier_free_guidance:
1749
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1750
+ noise_pred_tile = noise_pred_uncond + guidance_scale * (
1751
+ noise_pred_text - noise_pred_uncond
1752
+ )
1753
+ noise_preds_row.append(noise_pred_tile)
1754
+ noise_preds.append(noise_preds_row)
1755
+
1756
+ # Stitch noise predictions for all tiles
1757
+ noise_pred = torch.zeros(latents.shape, device=device)
1758
+ contributors = torch.zeros(latents.shape, device=device)
1759
+
1760
+ # Add each tile contribution to overall latents
1761
+ for row in range(grid_rows):
1762
+ for col in range(grid_cols):
1763
+ tile_row_overlap = tile_row_overlaps[row, col]
1764
+ tile_col_overlap = tile_col_overlaps[row, col]
1765
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
1766
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
1767
+ )
1768
+ tile_weights_resized = tile_weights[row, col]
1769
+
1770
+ noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
1771
+ noise_preds[row][col] * tile_weights_resized
1772
+ )
1773
+ contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights_resized
1774
+
1775
+ # Average overlapping areas with more than 1 contributor
1776
+ noise_pred /= contributors
1777
+ noise_pred = noise_pred.to(dtype)
1778
+
1779
+ # compute the previous noisy sample x_t -> x_t-1
1780
+ latents_dtype = latents.dtype
1781
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1782
+ if latents.dtype != latents_dtype:
1783
+ if torch.backends.mps.is_available():
1784
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1785
+ latents = latents.to(latents_dtype)
1786
+
1787
+ # update progress bar
1788
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1789
+ progress_bar.update()
1790
+
1791
+ if XLA_AVAILABLE:
1792
+ xm.mark_step()
1793
+
1794
+ # If we do sequential model offloading, let's offload unet and controlnet
1795
+ # manually for max memory savings
1796
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1797
+ self.unet.to("cpu")
1798
+ self.controlnet.to("cpu")
1799
+ torch.cuda.empty_cache()
1800
+
1801
+ if not output_type == "latent":
1802
+ # make sure the VAE is in float32 mode, as it overflows in float16
1803
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1804
+
1805
+ if needs_upcasting:
1806
+ self.upcast_vae()
1807
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1808
+
1809
+ # unscale/denormalize the latents
1810
+ # denormalize with the mean and std if available and not None
1811
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1812
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1813
+ if has_latents_mean and has_latents_std:
1814
+ latents_mean = (
1815
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1816
+ )
1817
+ latents_std = (
1818
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1819
+ )
1820
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1821
+ else:
1822
+ latents = latents / self.vae.config.scaling_factor
1823
+
1824
+ image = self.vae.decode(latents, return_dict=False)[0]
1825
+
1826
+ # cast back to fp16 if needed
1827
+ if needs_upcasting:
1828
+ self.vae.to(dtype=torch.float16)
1829
+
1830
+ # apply watermark if available
1831
+ if self.watermark is not None:
1832
+ image = self.watermark.apply_watermark(image)
1833
+
1834
+ image = self.image_processor.postprocess(image, output_type=output_type)
1835
+ else:
1836
+ image = latents
1837
+
1838
+ # Offload all models
1839
+ self.maybe_free_model_hooks()
1840
+
1841
+ result = StableDiffusionXLPipelineOutput(images=image)
1842
+ if not return_dict:
1843
+ return (image,)
1844
+
1845
+ return result
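The stitching step in the denoising loop above (accumulate weighted tile predictions, accumulate the weights, then divide) can be illustrated with a small one-dimensional sketch. This is not the pipeline's implementation: `cosine_weights` and `blend_tiles` are hypothetical helpers, the raised-cosine window is an assumed weight shape, and plain Python lists stand in for latent tensors.

```python
import math


def cosine_weights(n):
    """Raised-cosine window of length n (assumed shape): small near the edges, largest in the middle."""
    return [0.5 * (1.0 - math.cos(2.0 * math.pi * (i + 0.5) / n)) for i in range(n)]


def blend_tiles(length, tile_size, overlap, predict):
    """Blend per-tile predictions over a 1D signal by weighted averaging.

    `predict(start, end)` is a hypothetical callback that returns one value per
    position in [start, end). It plays the role of the per-tile noise prediction.
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile_size")
    total = [0.0] * length         # weighted sum of predictions
    contributors = [0.0] * length  # sum of weights per position
    weights = cosine_weights(tile_size)
    start = 0
    while True:
        end = min(start + tile_size, length)
        values = predict(start, end)
        for k in range(end - start):
            total[start + k] += values[k] * weights[k]
            contributors[start + k] += weights[k]
        if end == length:
            break
        start += stride
    # Normalize so overlapping regions become a weighted average
    return [t / c for t, c in zip(total, contributors)]


# Two tiles that disagree: the overlap region transitions between them
blended = blend_tiles(6, 4, 2, lambda s, e: [float(s)] * (e - s))
print(blended)
```

Dividing by the accumulated weights rather than by a tile count is what lets a non-uniform window fade each tile out toward its edges, which is the purpose of the `contributors` tensor in the loop above.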
pipeline/util.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright 2025 The DEVAIEXP Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import gc
17
+ import cv2
18
+ import numpy as np
19
+ import torch
20
+ from PIL import Image
21
+ from gradio.themes import Default
22
+ import gradio as gr
23
+
24
+
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ SAMPLERS = {
27
+ "DDIM": ("DDIMScheduler", {}),
28
+ "DDIM trailing": ("DDIMScheduler", {"timestep_spacing": "trailing"}),
29
+ "DDPM": ("DDPMScheduler", {}),
30
+ "DEIS": ("DEISMultistepScheduler", {}),
31
+ "Heun": ("HeunDiscreteScheduler", {}),
32
+ "Heun Karras": ("HeunDiscreteScheduler", {"use_karras_sigmas": True}),
33
+ "Euler": ("EulerDiscreteScheduler", {}),
34
+ "Euler trailing": ("EulerDiscreteScheduler", {"timestep_spacing": "trailing", "prediction_type": "sample"}),
35
+ "Euler Ancestral": ("EulerAncestralDiscreteScheduler", {}),
36
+ "Euler Ancestral trailing": ("EulerAncestralDiscreteScheduler", {"timestep_spacing": "trailing"}),
37
+ "DPM++ 1S": ("DPMSolverMultistepScheduler", {"solver_order": 1}),
38
+ "DPM++ 1S Karras": ("DPMSolverMultistepScheduler", {"solver_order": 1, "use_karras_sigmas": True}),
39
+ "DPM++ 2S": ("DPMSolverSinglestepScheduler", {"use_karras_sigmas": False}),
40
+ "DPM++ 2S Karras": ("DPMSolverSinglestepScheduler", {"use_karras_sigmas": True}),
41
+ "DPM++ 2M": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": False}),
42
+ "DPM++ 2M Karras": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": True}),
43
+ "DPM++ 2M SDE": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
44
+ "DPM++ 2M SDE Karras": (
45
+ "DPMSolverMultistepScheduler",
46
+ {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"},
47
+ ),
48
+ "DPM++ 3M": ("DPMSolverMultistepScheduler", {"solver_order": 3}),
49
+ "DPM++ 3M Karras": ("DPMSolverMultistepScheduler", {"solver_order": 3, "use_karras_sigmas": True}),
50
+ "DPM++ SDE": ("DPMSolverSDEScheduler", {"use_karras_sigmas": False}),
51
+ "DPM++ SDE Karras": ("DPMSolverSDEScheduler", {"use_karras_sigmas": True}),
52
+ "DPM2": ("KDPM2DiscreteScheduler", {}),
53
+ "DPM2 Karras": ("KDPM2DiscreteScheduler", {"use_karras_sigmas": True}),
54
+ "DPM2 Ancestral": ("KDPM2AncestralDiscreteScheduler", {}),
55
+ "DPM2 Ancestral Karras": ("KDPM2AncestralDiscreteScheduler", {"use_karras_sigmas": True}),
56
+ "LMS": ("LMSDiscreteScheduler", {}),
57
+ "LMS Karras": ("LMSDiscreteScheduler", {"use_karras_sigmas": True}),
58
+ "UniPC": ("UniPCMultistepScheduler", {}),
59
+ "UniPC Karras": ("UniPCMultistepScheduler", {"use_karras_sigmas": True}),
60
+ "PNDM": ("PNDMScheduler", {}),
61
+ "Euler EDM": ("EDMEulerScheduler", {}),
62
+ "Euler EDM Karras": ("EDMEulerScheduler", {"use_karras_sigmas": True}),
63
+ "DPM++ 2M EDM": (
64
+ "EDMDPMSolverMultistepScheduler",
65
+ {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"},
66
+ ),
67
+ "DPM++ 2M EDM Karras": (
68
+ "EDMDPMSolverMultistepScheduler",
69
+ {
70
+ "use_karras_sigmas": True,
71
+ "solver_order": 2,
72
+ "solver_type": "midpoint",
73
+ "final_sigmas_type": "zero",
74
+ "algorithm_type": "dpmsolver++",
75
+ },
76
+ ),
77
+ "DPM++ 2M Lu": ("DPMSolverMultistepScheduler", {"use_lu_lambdas": True}),
78
+ "DPM++ 2M Ef": ("DPMSolverMultistepScheduler", {"euler_at_final": True}),
79
+ "DPM++ 2M SDE Lu": ("DPMSolverMultistepScheduler", {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
80
+ "DPM++ 2M SDE Ef": ("DPMSolverMultistepScheduler", {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
81
+ "LCM": ("LCMScheduler", {}),
82
+ "LCM trailing": ("LCMScheduler", {"timestep_spacing": "trailing"}),
83
+ "TCD": ("TCDScheduler", {}),
84
+ "TCD trailing": ("TCDScheduler", {"timestep_spacing": "trailing"}),
85
+ }
86
+
87
+ class Platinum(Default):
88
+ def __init__(
89
+ self,
90
+ ):
91
+ super().__init__(
92
+ font = (
93
+ gr.themes.GoogleFont("Karla"), 'Segoe UI Emoji', 'Public Sans', 'system-ui', 'sans-serif'
94
+ )
95
+ )
96
+ self.name = "Diffusers"
97
+ super().set(
98
+ block_border_width='1px',
99
+ block_border_width_dark='1px',
100
+ block_info_text_size='13px',
101
+ block_info_text_weight='450',
102
+ block_info_text_color='#474a50',
103
+ block_label_background_fill='*background_fill_secondary',
104
+ block_label_text_color='*neutral_700',
105
+ block_title_text_color='black',
106
+ block_title_text_weight='600',
107
+ block_background_fill='#fcfcfc',
108
+ body_background_fill='*background_fill_secondary',
109
+ body_text_color='black',
110
+ background_fill_secondary='#f8f8f8',
111
+ border_color_accent='*primary_50',
112
+ border_color_primary='#ededed',
113
+ color_accent='#7367f0',
114
+ color_accent_soft='#fcfcfc',
115
+ panel_background_fill='#fcfcfc',
116
+ section_header_text_weight='600',
117
+ checkbox_background_color='*background_fill_secondary',
118
+ input_background_fill='white',
119
+ input_placeholder_color='*neutral_300',
120
+ loader_color = '#7367f0',
121
+ slider_color='#7367f0',
122
+ table_odd_background_fill='*neutral_100',
123
+ button_small_radius='*radius_sm',
124
+ button_primary_background_fill='linear-gradient(to bottom right, #7367f0, #9c93f4)',
125
+ button_primary_background_fill_hover='linear-gradient(to bottom right, #9c93f4, #9c93f4)',
126
+ button_primary_background_fill_hover_dark='linear-gradient(to bottom right, #5e50ee, #5e50ee)',
127
+ button_cancel_background_fill='linear-gradient(to bottom right, #fc0379, #ff88ac)',
128
+ button_cancel_background_fill_dark='linear-gradient(to bottom right, #dc2626, #b91c1c)',
129
+ button_cancel_background_fill_hover='linear-gradient(to bottom right, #f592c9, #f592c9)',
130
+ button_cancel_background_fill_hover_dark='linear-gradient(to bottom right, #dc2626, #dc2626)',
131
+ button_primary_border_color='#5949ed',
132
+ button_primary_text_color='white',
133
+ button_cancel_text_color='white',
134
+ button_cancel_text_color_dark='#dc2626',
135
+ button_cancel_border_color='#f04668',
136
+ button_cancel_border_color_dark='#dc2626',
137
+ button_cancel_border_color_hover='#fe6565',
138
+ button_cancel_border_color_hover_dark='#dc2626',
139
+ form_gap_width='1px',
140
+ layout_gap='5px'
141
+ )
142
+
143
+
144
+ def select_scheduler(pipe, selected_sampler):
145
+ import diffusers
146
+
147
+ scheduler_class_name, add_kwargs = SAMPLERS[selected_sampler]
148
+ config = pipe.scheduler.config
149
+ scheduler = getattr(diffusers, scheduler_class_name)
150
+ if selected_sampler in ("LCM", "LCM trailing"):
151
+ config = {
152
+ x: config[x] for x in config if x not in ("skip_prk_steps", "interpolation_type", "use_karras_sigmas")
153
+ }
154
+ elif selected_sampler in ("TCD", "TCD trailing"):
155
+ config = {x: config[x] for x in config if x not in ("skip_prk_steps",)}
156
+
157
+ return scheduler.from_config(config, **add_kwargs)
158
+
159
+
160
+ def calculate_overlap(width, height, base_overlap=128):
161
+ """
162
+ Calculates dynamic overlap based on the image's aspect ratio.
163
+
164
+ Args:
165
+ width (int): Width of the image in pixels.
166
+ height (int): Height of the image in pixels.
167
+ base_overlap (int, optional): Base overlap value in pixels. Defaults to 128.
168
+
169
+ Returns:
170
+ tuple: A tuple containing:
171
+ - row_overlap (int): Overlap between tiles in consecutive rows.
172
+ - col_overlap (int): Overlap between tiles in consecutive columns.
173
+ """
174
+ ratio = height / width
175
+ if ratio < 1: # Image is wider than tall
176
+ return base_overlap // 2, base_overlap
177
+ else: # Image is square or taller than wide
178
+ return base_overlap, base_overlap * 2
179
+
180
+
181
+ # def calculate_overlap(width, height, base_overlap=128, scale=4):
182
+ # """
183
+ # Calculates dynamic overlap based on the image's aspect ratio and resolution.
184
+ # For scales less than 4, the overlap is fixed at 64, 128 (or 128, 256).
185
+ # For scales 4 or greater, the overlap is adjusted proportionally to the scale.
186
+
187
+ # Args:
188
+ # width (int): Width of the image in pixels.
189
+ # height (int): Height of the image in pixels.
190
+ # base_overlap (int, optional): Base overlap value in pixels. Defaults to 128.
191
+ # scale (int, optional): Scale factor for calculating the overlap. Defaults to 4.
192
+
193
+ # Returns:
194
+ # tuple: A tuple containing:
195
+ # - row_overlap (int): Overlap between tiles in consecutive rows.
196
+ # - col_overlap (int): Overlap between tiles in consecutive columns.
197
+ # """
198
+ # # Define the base scale (4)
199
+ # base_scale = 4
200
+
201
+ # # If scale is less than 4, use fixed overlap values
202
+ # if scale < base_scale:
203
+ # ratio = height / width
204
+ # if ratio < 1: # Image is wider than tall
205
+ # return base_overlap // 2, base_overlap
206
+ # else: # Image is taller than wide
207
+ # return base_overlap, base_overlap * 2
208
+ # else:
209
+ # # For scales 4 or greater, adjust overlap proportionally
210
+ # scaling_factor = scale / base_scale
211
+ # base_overlap = int(base_overlap * base_scale)
212
+ # #base_overlap = int(base_overlap * scaling_factor)
213
+
214
+ # ratio = height / width
215
+ # if ratio < 1: # Image is wider than tall
216
+ # return base_overlap // 2, base_overlap
217
+ # else: # Image is taller than wide
218
+ # return base_overlap, base_overlap * 2
219
+
220
+
221
+ # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
222
+ def progressive_upscale(input_image, target_resolution, steps=3):
223
+ """
224
+ Progressively upscales an image to the target resolution in multiple steps.
225
+
226
+ Args:
227
+ input_image (PIL.Image.Image): The input image to be upscaled.
228
+ target_resolution (int): The target resolution (width or height) in pixels.
229
+ steps (int, optional): The number of upscaling steps. Defaults to 3.
230
+
231
+ Returns:
232
+ PIL.Image.Image: The upscaled image at the target resolution.
233
+ """
234
+ current_image = input_image.convert("RGB")
235
+ current_size = max(current_image.size)
236
+
237
+ # Upscale in multiple steps
238
+ for _ in range(steps):
239
+ if current_size >= target_resolution:
240
+ break
241
+ scale_factor = min(2, target_resolution / current_size)
242
+ new_size = (int(current_image.width * scale_factor), int(current_image.height * scale_factor))
243
+ current_image = current_image.resize(new_size, Image.LANCZOS)
244
+ current_size = max(current_image.size)
245
+
246
+ # Final resize to exact target resolution
247
+ if current_size != target_resolution:
248
+ aspect_ratio = current_image.width / current_image.height
249
+ if current_image.width > current_image.height:
250
+ new_size = (target_resolution, int(target_resolution / aspect_ratio))
251
+ else:
252
+ new_size = (int(target_resolution * aspect_ratio), target_resolution)
253
+ current_image = current_image.resize(new_size, Image.LANCZOS)
254
+
255
+ return current_image
256
+
257
+
258
+ # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
259
+ def create_hdr_effect(original_image, hdr):
260
+ """
261
+ Applies an HDR (High Dynamic Range) effect to an image based on the specified intensity.
262
+
263
+ Args:
264
+ original_image (PIL.Image.Image): The original image to which the HDR effect will be applied.
265
+ hdr (float): The intensity of the HDR effect, ranging from 0 (no effect) to 1 (maximum effect).
266
+
267
+ Returns:
268
+ PIL.Image.Image: The image with the HDR effect applied.
269
+ """
270
+ if hdr == 0:
271
+ return original_image # No effect applied if hdr is 0
272
+
273
+ # Convert the PIL image to a NumPy array in BGR format (OpenCV format)
274
+ cv_original = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
275
+
276
+ # Define scaling factors for creating multiple exposures
277
+ factors = [
278
+ 1.0 - 0.9 * hdr,
279
+ 1.0 - 0.7 * hdr,
280
+ 1.0 - 0.45 * hdr,
281
+ 1.0 - 0.25 * hdr,
282
+ 1.0,
283
+ 1.0 + 0.2 * hdr,
284
+ 1.0 + 0.4 * hdr,
285
+ 1.0 + 0.6 * hdr,
286
+ 1.0 + 0.8 * hdr,
287
+ ]
288
+
289
+ # Generate multiple exposure images by scaling the original image
290
+ images = [cv2.convertScaleAbs(cv_original, alpha=factor) for factor in factors]
291
+
292
+ # Merge the images using the Mertens algorithm to create an HDR effect
293
+ merge_mertens = cv2.createMergeMertens()
294
+ hdr_image = merge_mertens.process(images)
295
+
296
+ # Convert the HDR image to 8-bit format (0-255 range)
297
+ hdr_image_8bit = np.clip(hdr_image * 255, 0, 255).astype("uint8")
298
+
299
+ # Convert the image back to RGB format and return as a PIL image
300
+ return Image.fromarray(cv2.cvtColor(hdr_image_8bit, cv2.COLOR_BGR2RGB))
301
+
302
+
303
+ def torch_gc():
304
+ if torch.cuda.is_available():
305
+ with torch.cuda.device("cuda"):
306
+ torch.cuda.empty_cache()
307
+ torch.cuda.ipc_collect()
308
+
309
+ gc.collect()
310
+
311
+
312
+ def quantize_8bit(unet):
313
+ if unet is None:
314
+ return
315
+
316
+ from peft.tuners.tuners_utils import BaseTunerLayer
317
+
318
+ dtype = unet.dtype
319
+ unet.to(torch.float8_e4m3fn)
320
+ for module in unet.modules(): # revert lora modules to prevent errors with fp8
321
+ if isinstance(module, BaseTunerLayer):
322
+ module.to(dtype)
323
+
324
+ if hasattr(unet, "encoder_hid_proj"): # revert ip adapter modules to prevent errors with fp8
325
+ if unet.encoder_hid_proj is not None:
326
+ for module in unet.encoder_hid_proj.modules():
327
+ module.to(dtype)
328
+ torch_gc()
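To see which intermediate sizes `progressive_upscale` passes through, the same arithmetic can be run on plain `(width, height)` tuples without loading an image. The helper below, `upscale_schedule`, is a hypothetical name introduced only for this illustration; it mirrors the doubling loop and the final exact resize, and leaves out the Lanczos resampling.

```python
def upscale_schedule(size, target_resolution, steps=3):
    """Return the sequence of (width, height) sizes the progressive upscale would produce."""
    width, height = size
    sizes = []
    # At most `steps` passes, each scaling by no more than 2x
    for _ in range(steps):
        longest = max(width, height)
        if longest >= target_resolution:
            break
        scale = min(2, target_resolution / longest)
        width, height = int(width * scale), int(height * scale)
        sizes.append((width, height))
    # Final resize so the longest side matches the target exactly
    if max(width, height) != target_resolution:
        aspect_ratio = width / height
        if width > height:
            width, height = target_resolution, int(target_resolution / aspect_ratio)
        else:
            width, height = int(target_resolution * aspect_ratio), target_resolution
        sizes.append((width, height))
    return sizes


print(upscale_schedule((512, 384), 2048))
print(upscale_schedule((256, 256), 4096))
```

With the default `steps=3` the loop can enlarge the image by at most 8x, so any larger target is reached by the final resize in a single jump, as the second call shows.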
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch
2
+ peft
3
+ opencv-python
4
+ spaces
5
+ scipy
6
+ gradio==5.15.0
7
+ numpy==1.26.4
8
+ transformers
9
+ accelerate
10
+ diffusers
11
+ fastapi>=0.115.2