OriLib committed (verified)
Commit 196e472 · 1 Parent(s): d94978e

Upload 2 files
replace_bg/model/controlnet.py CHANGED

@@ -19,7 +19,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalControlNetMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
 from diffusers.utils import BaseOutput, logging
 from diffusers.models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -54,7 +54,7 @@ class ControlNetOutput(BaseOutput):
             be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
             used to condition the original UNet's downsampling activations.
         mid_down_block_re_sample (`torch.Tensor`):
-            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
             `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
             Output can be used to condition the original UNet's middle block activation.
     """
@@ -76,12 +76,12 @@ class ControlNetConditioningEmbedding(nn.Module):
     def __init__(
         self,
         conditioning_embedding_channels: int,
-        conditioning_channels: int = 5, #update to 5
+        conditioning_channels: int = 5, # Bria: update to 5
         block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
     ):
         super().__init__()
 
-        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
 
         self.blocks = nn.ModuleList([])
 
@@ -89,7 +89,7 @@ class ControlNetConditioningEmbedding(nn.Module):
             channel_in = block_out_channels[i]
             channel_out = block_out_channels[i + 1]
             self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
-            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=1)) # update to 1
+            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=1)) # Bria: update stride to 1
 
         self.conv_out = zero_module(
             nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
@@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         return embedding
 
 
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A ControlNet model.
 
@@ -530,7 +530,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -665,10 +665,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: torch.FloatTensor,
+        controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -677,18 +677,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
-    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
         """
         The [`ControlNetModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor.
             timestep (`Union[torch.Tensor, float, int]`):
                 The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
-            controlnet_cond (`torch.FloatTensor`):
+            controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
@@ -710,12 +710,13 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
                 In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
             return_dict (`bool`, defaults to `True`):
-                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+                Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
+                tuple.
 
         Returns:
-            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
-                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
-                returned where the first element is the sample tensor.
+            [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
+                If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned,
+                otherwise a tuple is returned where the first element is the sample tensor.
         """
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
@@ -868,4 +869,4 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
 def zero_module(module):
     for p in module.parameters():
         nn.init.zeros_(p)
-    return module
+    return module
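
A minimal sketch (not part of the diff above) of what the patched conditioning embedding expects after this commit: `conditioning_channels` defaults to 5 and every downsampling stride is 1, so the embedding keeps the spatial size of its latent-resolution input. The 5-channel split (4 VAE latent channels plus a 1-channel mask) and the 320-channel output are assumptions used only for illustration.

# Editor's sketch, assumptions labeled inline; not part of the commit.
import torch
from replace_bg.model.controlnet import ControlNetConditioningEmbedding

# conditioning_channels defaults to 5 in this repo (assumed: 4 latent channels + 1 mask channel)
embedder = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)  # 320 is an illustrative value
cond = torch.randn(1, 5, 128, 128)  # latent-resolution conditioning tensor (shape assumed)
out = embedder(cond)
print(out.shape)  # torch.Size([1, 320, 128, 128]) -- spatial size preserved because every stride is now 1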
replace_bg/model/pipeline_controlnet_sd_xl.py CHANGED

@@ -37,8 +37,8 @@ from diffusers.loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-
 from .controlnet import ControlNetModel
+# from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -57,9 +57,9 @@ from diffusers.utils import (
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 
 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
@@ -116,8 +116,69 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
 class StableDiffusionXLControlNetPipeline(
     DiffusionPipeline,
+    StableDiffusionMixin,
     TextualInversionLoaderMixin,
     StableDiffusionXLLoraLoaderMixin,
     IPAdapterMixin,
@@ -176,7 +237,16 @@ class StableDiffusionXLControlNetPipeline(
         "feature_extractor",
         "image_encoder",
     ]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "negative_add_time_ids",
+        "image",
+    ]
 
     def __init__(
         self,
@@ -224,39 +294,6 @@ class StableDiffusionXLControlNetPipeline(
 
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.vae.enable_slicing()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_slicing()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.vae.enable_tiling()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_tiling()
-
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
@@ -267,10 +304,10 @@ class StableDiffusionXLControlNetPipeline(
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[str] = None,
         negative_prompt_2: Optional[str] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -296,17 +333,17 @@ class StableDiffusionXLControlNetPipeline(
             negative_prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                 `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
@@ -353,7 +390,7 @@ class StableDiffusionXLControlNetPipeline(
         prompt_2 = prompt_2 or prompt
         prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
-        # textual inversion: procecss multi-vector tokens if necessary
+        # textual inversion: process multi-vector tokens if necessary
         prompt_embeds_list = []
         prompts = [prompt, prompt_2]
         for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
@@ -518,33 +555,50 @@ class StableDiffusionXLControlNetPipeline(
         return image_embeds, uncond_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
 
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
 
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
 
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
+                image_embeds.append(single_image_embeds[None, :])
+                if do_classifier_free_guidance:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
 
-            image_embeds.append(single_image_embeds)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
 
-        return image_embeds
+        return ip_adapter_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -575,6 +629,8 @@ class StableDiffusionXLControlNetPipeline(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         negative_pooled_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
@@ -736,6 +792,21 @@ class StableDiffusionXLControlNetPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
@@ -807,7 +878,12 @@ class StableDiffusionXLControlNetPipeline(
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -851,8 +927,6 @@ class StableDiffusionXLControlNetPipeline(
             (
                 AttnProcessor2_0,
                 XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
@@ -862,49 +936,23 @@ class StableDiffusionXLControlNetPipeline(
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
-    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
-
-        The suffixes after the scaling factors represent the stages where they are being applied.
-
-        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
-        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
-
-        Args:
-            s1 (`float`):
-                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
-                mitigate "oversmoothing effect" in the enhanced denoising process.
-            s2 (`float`):
-                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
-                mitigate "oversmoothing effect" in the enhanced denoising process.
-            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
-            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
-        """
-        if not hasattr(self, "unet"):
-            raise ValueError("The pipeline must have `unet` for using FreeU.")
-        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
-    def disable_freeu(self):
-        """Disables the FreeU mechanism if enabled."""
-        self.unet.disable_freeu()
-
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
-    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+    def get_guidance_scale_embedding(
+        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
 
        Args:
-            timesteps (`torch.Tensor`):
-                generate embedding vectors at these timesteps
+            w (`torch.Tensor`):
+                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
-                dimension of the embeddings to generate
-            dtype:
-                data type of the generated embeddings
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
 
        Returns:
-            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0
@@ -938,10 +986,18 @@ class StableDiffusionXLControlNetPipeline(
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs
 
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
    @property
    def num_timesteps(self):
        return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -952,18 +1008,22 @@ class StableDiffusionXLControlNetPipeline(
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        sigmas: List[float] = None,
+        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -978,7 +1038,9 @@ class StableDiffusionXLControlNetPipeline(
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
@@ -991,14 +1053,14 @@ class StableDiffusionXLControlNetPipeline(
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
-                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
-                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
-                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
-                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
-                `init`, images must be passed as a list such that each element of the list can be correctly batched for
-                input to a single ControlNet.
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1010,6 +1072,21 @@ class StableDiffusionXLControlNetPipeline(
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1027,24 +1104,29 @@ class StableDiffusionXLControlNetPipeline(
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, pooled text embeddings are generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
                weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
                argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
@@ -1096,15 +1178,15 @@ class StableDiffusionXLControlNetPipeline(
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
        Examples:
 
@@ -1130,6 +1212,9 @@ class StableDiffusionXLControlNetPipeline(
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
        # align format for control guidance
@@ -1155,6 +1240,8 @@ class StableDiffusionXLControlNetPipeline(
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
            negative_pooled_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
@@ -1165,6 +1252,8 @@ class StableDiffusionXLControlNetPipeline(
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._interrupt = False
 
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
@@ -1212,9 +1301,13 @@ class StableDiffusionXLControlNetPipeline(
        )
 
        # 3.2 Encode ip_adapter_image
-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
            )
 
        # 4. Prepare image
@@ -1231,7 +1324,7 @@ class StableDiffusionXLControlNetPipeline(
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
-            height, width = height*self.vae_scale_factor, width*self.vae_scale_factor # for vae controlnet
+            height, width = height*self.vae_scale_factor, width*self.vae_scale_factor # Bria: update for vae controlnet
        elif isinstance(controlnet, MultiControlNetModel):
            images = []
 
@@ -1256,8 +1349,9 @@ class StableDiffusionXLControlNetPipeline(
            assert False
 
        # 5. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas
+        )
        self._num_timesteps = len(timesteps)
 
        # 6. Prepare latent variables
@@ -1336,11 +1430,31 @@ class StableDiffusionXLControlNetPipeline(
 
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        # 8.1 Apply denoising_end
+        if (
+            self.denoising_end is not None
+            and isinstance(self.denoising_end, float)
+            and self.denoising_end > 0
+            and self.denoising_end < 1
+        ):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
@@ -1386,13 +1500,13 @@ class StableDiffusionXLControlNetPipeline(
                )
 
                if guess_mode and self.do_classifier_free_guidance:
-                    # Infered ControlNet only for the conditional batch.
+                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
-                if ip_adapter_image is not None:
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds
 
                # predict the noise residual
@@ -1425,6 +1539,13 @@ class StableDiffusionXLControlNetPipeline(
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+                    image = callback_outputs.pop("image", image)
 
                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1441,7 +1562,22 @@ class StableDiffusionXLControlNetPipeline(
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
 
            # cast back to fp16 if needed
            if needs_upcasting:
@@ -1462,4 +1598,4 @@ class StableDiffusionXLControlNetPipeline(
        if not return_dict:
            return (image,)
 
-        return StableDiffusionXLPipelineOutput(images=image)
+        return StableDiffusionXLPipelineOutput(images=image)
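
A minimal sketch (not part of the diff above) of the module-level `retrieve_timesteps` helper the new pipeline file defines: it routes either `num_inference_steps`, a custom `timesteps` list, or a custom `sigmas` list into `scheduler.set_timesteps` and returns the resulting schedule together with its length, which `__call__` now uses alongside the new `denoising_end`, `ip_adapter_image_embeds`, `interrupt`, and callback-class arguments. The scheduler choice below is illustrative and assumes this repo plus a recent diffusers release are installed.

# Editor's sketch, assumptions labeled inline; not part of the commit.
from diffusers import DDIMScheduler
from replace_bg.model.pipeline_controlnet_sd_xl import retrieve_timesteps

scheduler = DDIMScheduler(num_train_timesteps=1000)  # illustrative scheduler

# Default path: equivalent to scheduler.set_timesteps(30, device="cpu"), returning the schedule and its length.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_steps, timesteps[:3])

# Custom schedules: `timesteps=[...]` or `sigmas=[...]` are forwarded to scheduler.set_timesteps when the
# scheduler accepts those keywords; otherwise retrieve_timesteps raises a ValueError, as implemented above.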