ameerazam08 commited on
Commit
0a82683
·
verified ·
1 Parent(s): e8bec9f

Upload 13 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/R-F.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from src.pipeline_pe_clone import FluxPipeline
5
+ import spaces
6
+
7
+ @spaces.GPU()
8
+ def generate_image(model_path, image, height, width, prompt, guidance_scale, num_steps, lora_name):
9
+ # Load the model
10
+ pipeline = FluxPipeline.from_pretrained(
11
+ model_path,
12
+ torch_dtype=torch.bfloat16,
13
+ ).to('cuda')
14
+
15
+ # Load and fuse base LoRA weights
16
+ pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name="pretrain.safetensors")
17
+ pipeline.fuse_lora()
18
+ pipeline.unload_lora_weights()
19
+
20
+ # Load selected LoRA effect if not using the pretrained base model
21
+ if lora_name != 'pretrained':
22
+ pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name=f"{lora_name}.safetensors")
23
+
24
+ # Prepare the input image
25
+ condition_image = image.resize((height, width)).convert("RGB")
26
+
27
+ # Generate the output image
28
+ result = pipeline(
29
+ prompt=prompt,
30
+ condition_image=condition_image,
31
+ height=height,
32
+ width=width,
33
+ guidance_scale=guidance_scale,
34
+ num_inference_steps=num_steps,
35
+ max_sequence_length=512
36
+ ).images[0]
37
+
38
+ return result
39
+
40
+ # Create Gradio interface
41
+ iface = gr.Interface(
42
+ fn=generate_image,
43
+ inputs=[
44
+ gr.Textbox(label="Model Path", value="black-forest-labs/FLUX.1-dev"),
45
+ gr.Image(label="Input Image", type="pil"),
46
+ gr.Number(label="Height", value=768),
47
+ gr.Number(label="Width", value=512),
48
+ gr.Textbox(label="Prompt", value="add a halo and wings for the cat by sksmagiceffects"),
49
+ gr.Number(label="Guidance Scale", value=3.5),
50
+ gr.Number(label="Number of Steps", value=20),
51
+ gr.Dropdown(
52
+ label="LoRA Name",
53
+ choices=["pretrained", "sksmagiceffects", "sksmonstercalledlulu",
54
+ "skspaintingeffects", "sksedgeeffect", "skscatooneffect"],
55
+ value="sksmagiceffects"
56
+ )
57
+ ],
58
+ outputs=gr.Image(label="Output Image", type="pil"),
59
+ title="FLUX Image Generation with LoRA"
60
+ )
61
+
62
+ if __name__ == "__main__":
63
+ iface.launch()
assets/1.png ADDED
assets/R-F.jpg ADDED

Git LFS Details

  • SHA256: c37533cd09e5d5da972a4c5c7bb6093bb376b35b27cb10899d5ca1f608ebbad1
  • Pointer size: 132 Bytes
  • Size of remote file: 7.85 MB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 939bee83f31172b79c826e6dcb25978f11ee12e18324f65adb4742c59bb95d00
  • Pointer size: 132 Bytes
  • Size of remote file: 3.18 MB
inference.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from src.pipeline_pe_clone import FluxPipeline
3
+ import torch
4
+ from PIL import Image
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser(description='FLUX image generation with LoRA')
8
+ parser.add_argument('--model_path', type=str,
9
+ default="black-forest-labs/FLUX.1-dev",
10
+ help='Path to pretrained model')
11
+ parser.add_argument('--image_path', type=str,
12
+ default="assets/1.png",
13
+ help='Input image path')
14
+ parser.add_argument('--output_path', type=str,
15
+ default="output.png",
16
+ help='Output image path')
17
+ parser.add_argument('--height', type=int, default=768)
18
+ parser.add_argument('--width', type=int, default=512)
19
+ parser.add_argument('--prompt', type=str,
20
+ default="add a halo and wings for the cat by sksmagiceffects",
21
+ help="""Different LoRA effects and their example prompts:
22
+ - sksmagiceffects: "add a halo and wings for the cat by sksmagiceffects"
23
+ - sksmonstercalledlulu: "add a red sksmonstercalledlulu hugging the cat"
24
+ - skspaintingeffects: "add a yellow flower on the cat's head and psychedelic colors and dynamic flows by skspaintingeffects"
25
+ - sksedgeeffect: "add yellow flames to the cat by sksedgeeffect"
26
+ - skscatooneffect: "add two hands holding the cat in skscatooneffect"
27
+ """)
28
+ parser.add_argument('--guidance_scale', type=float, default=3.5)
29
+ parser.add_argument('--num_steps', type=int, default=20,
30
+ help='Number of inference steps')
31
+ parser.add_argument('--lora_name', type=str,
32
+ choices=['pretrained', 'sksmagiceffects', 'sksmonstercalledlulu',
33
+ 'skspaintingeffects', 'sksedgeeffect', 'skscatooneffect'],
34
+ default="sksmagiceffects",
35
+ help='Name of LoRA weights to use. Use "pretrained" for base model only')
36
+ return parser.parse_args()
37
+
38
+ def main():
39
+ args = parse_args()
40
+
41
+ pipeline = FluxPipeline.from_pretrained(
42
+ args.model_path,
43
+ torch_dtype=torch.bfloat16,
44
+ ).to('cuda')
45
+
46
+ # Load and fuse base LoRA weights
47
+ pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name="pretrain.safetensors")
48
+ pipeline.fuse_lora()
49
+ pipeline.unload_lora_weights()
50
+
51
+ # Load selected LoRA effect only if not using pretrained
52
+ if args.lora_name != 'pretrained':
53
+ pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name=f"{args.lora_name}.safetensors")
54
+
55
+ condition_image = Image.open(args.image_path).resize((args.height, args.width)).convert("RGB")
56
+
57
+ result = pipeline(
58
+ prompt=args.prompt,
59
+ condition_image=condition_image,
60
+ height=args.height,
61
+ width=args.width,
62
+ guidance_scale=args.guidance_scale,
63
+ num_inference_steps=args.num_steps,
64
+ max_sequence_length=512
65
+ ).images[0]
66
+
67
+ result.save(args.output_path)
68
+
69
+ if __name__ == "__main__":
70
+ main()
merge.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from src.pipeline_pe_clone import FluxPipeline
3
+ import torch
4
+ from PIL import Image
5
+ pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
6
+ pipeline = FluxPipeline.from_pretrained(
7
+ pretrained_model_name_or_path,
8
+ torch_dtype=torch.bfloat16,
9
+ )
10
+ pipeline.load_lora_weights("outputs/doodle_pretrain_4508000/pytorch_lora_weights.safetensors")
11
+ pipeline.fuse_lora()
12
+ pipeline.unload_lora_weights()
13
+ pipeline.save_pretrained("edit_pretrain")
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.33.0
2
+ transformers==4.44.0
3
+ diffusers[torch]==0.25.0
4
+ ftfy==6.1.1
5
+ # albumentations==1.3.0
6
+ opencv-python==4.8.1.78
7
+ einops==0.7.0
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.44.0
10
+ prodigyopt==1.0
11
+ lion-pytorch==0.0.6
12
+ came_pytorch==0.1.3
13
+ schedulefree==1.4
14
+ tensorboard
15
+ safetensors==0.4.4
16
+ # for gradio
17
+ gradio==3.6
18
+ altair==4.2.2
19
+ easygui==0.98.3
20
+ toml==0.10.2
21
+ voluptuous==0.13.1
22
+ huggingface-hub==0.24.5
23
+ # for Image utils
24
+ imagesize==1.4.1
25
+ numpy<=2.0
26
+ rich==13.7.0
27
+ # for T5XXL tokenizer (SD3/FLUX)
28
+ sentencepiece==0.2.0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (137 Bytes). View file
 
src/__pycache__/pipeline_pe_clone.cpython-310.pyc ADDED
Binary file (21.8 kB). View file
 
src/jsonl_datasets.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from datasets import load_dataset
4
+ from torchvision import transforms
5
+ import random
6
+ import os
7
+ import numpy as np
8
+
9
+ Image.MAX_IMAGE_PIXELS = None
10
+
11
+ def make_train_dataset(args, tokenizer, accelerator=None):
12
+ if args.train_data_dir is not None:
13
+ print("load_data")
14
+ dataset = load_dataset('json', data_files=args.train_data_dir)
15
+
16
+ column_names = dataset["train"].column_names
17
+
18
+ # 6. Get the column names for input/target.
19
+ if args.caption_column is None:
20
+ caption_column = column_names[0]
21
+ print(f"caption column defaulting to {caption_column}")
22
+ else:
23
+ caption_column = args.caption_column
24
+ if caption_column not in column_names:
25
+ raise ValueError(
26
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
27
+ )
28
+ if args.source_column is None:
29
+ source_column = column_names[1]
30
+ print(f"source column defaulting to {source_column}")
31
+ else:
32
+ source_column = args.source_column
33
+ if source_column not in column_names:
34
+ raise ValueError(
35
+ f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
36
+ )
37
+ if args.target_column is None:
38
+ target_column = column_names[1]
39
+ print(f"target column defaulting to {target_column}")
40
+ else:
41
+ target_column = args.target_column
42
+ if target_column not in column_names:
43
+ raise ValueError(
44
+ f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
45
+ )
46
+
47
+ h = args.height
48
+ w = args.width
49
+ train_transforms = transforms.Compose(
50
+ [
51
+ transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5], [0.5]),
54
+ ]
55
+ )
56
+
57
+ tokenizer_clip = tokenizer[0]
58
+ tokenizer_t5 = tokenizer[1]
59
+
60
+ def tokenize_prompt_clip_t5(examples):
61
+ captions = []
62
+ for caption in examples[caption_column]:
63
+ if isinstance(caption, str):
64
+ captions.append(caption)
65
+ elif isinstance(caption, list):
66
+ captions.append(random.choice(caption))
67
+ else:
68
+ raise ValueError(
69
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
70
+ )
71
+ text_inputs = tokenizer_clip(
72
+ captions,
73
+ padding="max_length",
74
+ max_length=77,
75
+ truncation=True,
76
+ return_length=False,
77
+ return_overflowing_tokens=False,
78
+ return_tensors="pt",
79
+ )
80
+ text_input_ids_1 = text_inputs.input_ids
81
+
82
+ text_inputs = tokenizer_t5(
83
+ captions,
84
+ padding="max_length",
85
+ max_length=512,
86
+ truncation=True,
87
+ return_length=False,
88
+ return_overflowing_tokens=False,
89
+ return_tensors="pt",
90
+ )
91
+ text_input_ids_2 = text_inputs.input_ids
92
+ return text_input_ids_1, text_input_ids_2
93
+
94
+ def preprocess_train(examples):
95
+ _examples = {}
96
+
97
+ source_images = [Image.open(image).convert("RGB") for image in examples[source_column]]
98
+ target_images = [Image.open(image).convert("RGB") for image in examples[target_column]]
99
+
100
+ _examples["cond_pixel_values"] = [train_transforms(source) for source in source_images]
101
+ _examples["pixel_values"] = [train_transforms(image) for image in target_images]
102
+ _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
103
+
104
+ return _examples
105
+
106
+ if accelerator is not None:
107
+ with accelerator.main_process_first():
108
+ train_dataset = dataset["train"].with_transform(preprocess_train)
109
+ else:
110
+ train_dataset = dataset["train"].with_transform(preprocess_train)
111
+
112
+ return train_dataset
113
+
114
+
115
+ def collate_fn(examples):
116
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
117
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
118
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
119
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
120
+ token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
121
+ token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
122
+
123
+ return {
124
+ "cond_pixel_values": cond_pixel_values,
125
+ "pixel_values": target_pixel_values,
126
+ "text_ids_1": token_ids_clip,
127
+ "text_ids_2": token_ids_t5,
128
+ }
src/pipeline_pe_clone.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
7
+
8
+ from diffusers.image_processor import (VaeImageProcessor)
9
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
10
+ from diffusers.models.autoencoders import AutoencoderKL
11
+ from diffusers.models.transformers import FluxTransformer2DModel
12
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
13
+ from diffusers.utils import (
14
+ USE_PEFT_BACKEND,
15
+ is_torch_xla_available,
16
+ logging,
17
+ scale_lora_layers,
18
+ unscale_lora_layers,
19
+ )
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
22
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
23
+
24
+ if is_torch_xla_available():
25
+ import torch_xla.core.xla_model as xm
26
+
27
+ XLA_AVAILABLE = True
28
+ else:
29
+ XLA_AVAILABLE = False
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ def calculate_shift(
34
+ image_seq_len,
35
+ base_seq_len: int = 256,
36
+ max_seq_len: int = 4096,
37
+ base_shift: float = 0.5,
38
+ max_shift: float = 1.16,
39
+ ):
40
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
41
+ b = base_shift - m * base_seq_len
42
+ mu = image_seq_len * m + b
43
+ return mu
44
+
45
+ def prepare_latent_image_ids_2(height, width, device, dtype):
46
+ latent_image_ids = torch.zeros(height//2, width//2, 3, device=device, dtype=dtype)
47
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height//2, device=device)[:, None] # y坐标
48
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width//2, device=device)[None, :] # x坐标
49
+ return latent_image_ids
50
+
51
+ def position_encoding_clone(batch_size, original_height, original_width, device, dtype):
52
+ latent_image_ids = prepare_latent_image_ids_2(original_height, original_width, device, dtype)
53
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
54
+ latent_image_ids = latent_image_ids.reshape(
55
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
56
+ )
57
+ cond_latent_image_ids = latent_image_ids
58
+ latent_image_ids = torch.concat([latent_image_ids, cond_latent_image_ids], dim=-2)
59
+ return latent_image_ids
60
+
61
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
62
+ def retrieve_latents(
63
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
64
+ ):
65
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
66
+ return encoder_output.latent_dist.sample(generator)
67
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
68
+ return encoder_output.latent_dist.mode()
69
+ elif hasattr(encoder_output, "latents"):
70
+ return encoder_output.latents
71
+ else:
72
+ raise AttributeError("Could not access latents of provided encoder_output")
73
+
74
+
75
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
76
+ def retrieve_timesteps(
77
+ scheduler,
78
+ num_inference_steps: Optional[int] = None,
79
+ device: Optional[Union[str, torch.device]] = None,
80
+ timesteps: Optional[List[int]] = None,
81
+ sigmas: Optional[List[float]] = None,
82
+ **kwargs,
83
+ ):
84
+ """
85
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
86
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
87
+
88
+ Args:
89
+ scheduler (`SchedulerMixin`):
90
+ The scheduler to get timesteps from.
91
+ num_inference_steps (`int`):
92
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
93
+ must be `None`.
94
+ device (`str` or `torch.device`, *optional*):
95
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
96
+ timesteps (`List[int]`, *optional*):
97
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
98
+ `num_inference_steps` and `sigmas` must be `None`.
99
+ sigmas (`List[float]`, *optional*):
100
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
101
+ `num_inference_steps` and `timesteps` must be `None`.
102
+
103
+ Returns:
104
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
105
+ second element is the number of inference steps.
106
+ """
107
+ if timesteps is not None and sigmas is not None:
108
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
111
+ if not accepts_timesteps:
112
+ raise ValueError(
113
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
114
+ f" timestep schedules. Please check whether you are using the correct scheduler."
115
+ )
116
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
117
+ timesteps = scheduler.timesteps
118
+ num_inference_steps = len(timesteps)
119
+ elif sigmas is not None:
120
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
121
+ if not accept_sigmas:
122
+ raise ValueError(
123
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
124
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
125
+ )
126
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
127
+ timesteps = scheduler.timesteps
128
+ num_inference_steps = len(timesteps)
129
+ else:
130
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ return timesteps, num_inference_steps
133
+
134
+
135
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
136
+ r"""
137
+ The Flux pipeline for text-to-image generation.
138
+
139
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
140
+
141
+ Args:
142
+ transformer ([`FluxTransformer2DModel`]):
143
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
144
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
145
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
146
+ vae ([`AutoencoderKL`]):
147
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
148
+ text_encoder ([`CLIPTextModel`]):
149
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
150
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
151
+ text_encoder_2 ([`T5EncoderModel`]):
152
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
153
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
154
+ tokenizer (`CLIPTokenizer`):
155
+ Tokenizer of class
156
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
157
+ tokenizer_2 (`T5TokenizerFast`):
158
+ Second Tokenizer of class
159
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
160
+ """
161
+
162
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
163
+ _optional_components = []
164
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
165
+
166
+ def __init__(
167
+ self,
168
+ scheduler: FlowMatchEulerDiscreteScheduler,
169
+ vae: AutoencoderKL,
170
+ text_encoder: CLIPTextModel,
171
+ tokenizer: CLIPTokenizer,
172
+ text_encoder_2: T5EncoderModel,
173
+ tokenizer_2: T5TokenizerFast,
174
+ transformer: FluxTransformer2DModel,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.register_modules(
179
+ vae=vae,
180
+ text_encoder=text_encoder,
181
+ text_encoder_2=text_encoder_2,
182
+ tokenizer=tokenizer,
183
+ tokenizer_2=tokenizer_2,
184
+ transformer=transformer,
185
+ scheduler=scheduler,
186
+ )
187
+ self.vae_scale_factor = (
188
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
189
+ )
190
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
191
+ self.tokenizer_max_length = (
192
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
193
+ )
194
+ self.default_sample_size = 64
195
+
196
+ def _get_t5_prompt_embeds(
197
+ self,
198
+ prompt: Union[str, List[str]] = None,
199
+ num_images_per_prompt: int = 1,
200
+ max_sequence_length: int = 512,
201
+ device: Optional[torch.device] = None,
202
+ dtype: Optional[torch.dtype] = None,
203
+ ):
204
+ device = device or self._execution_device
205
+ dtype = dtype or self.text_encoder.dtype
206
+
207
+ prompt = [prompt] if isinstance(prompt, str) else prompt
208
+ batch_size = len(prompt)
209
+
210
+ text_inputs = self.tokenizer_2(
211
+ prompt,
212
+ padding="max_length",
213
+ max_length=max_sequence_length,
214
+ truncation=True,
215
+ return_length=False,
216
+ return_overflowing_tokens=False,
217
+ return_tensors="pt",
218
+ )
219
+ text_input_ids = text_inputs.input_ids
220
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
221
+
222
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
223
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
224
+ logger.warning(
225
+ "The following part of your input was truncated because `max_sequence_length` is set to "
226
+ f" {max_sequence_length} tokens: {removed_text}"
227
+ )
228
+
229
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
230
+
231
+ dtype = self.text_encoder_2.dtype
232
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
233
+
234
+ _, seq_len, _ = prompt_embeds.shape
235
+
236
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
237
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
238
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
239
+
240
+ return prompt_embeds
241
+
242
+ def _get_clip_prompt_embeds(
243
+ self,
244
+ prompt: Union[str, List[str]],
245
+ num_images_per_prompt: int = 1,
246
+ device: Optional[torch.device] = None,
247
+ ):
248
+ device = device or self._execution_device
249
+
250
+ prompt = [prompt] if isinstance(prompt, str) else prompt
251
+ batch_size = len(prompt)
252
+
253
+ text_inputs = self.tokenizer(
254
+ prompt,
255
+ padding="max_length",
256
+ max_length=self.tokenizer_max_length,
257
+ truncation=True,
258
+ return_overflowing_tokens=False,
259
+ return_length=False,
260
+ return_tensors="pt",
261
+ )
262
+
263
+ text_input_ids = text_inputs.input_ids
264
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
265
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
266
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
267
+ logger.warning(
268
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
269
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
270
+ )
271
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
272
+
273
+ # Use pooled output of CLIPTextModel
274
+ prompt_embeds = prompt_embeds.pooler_output
275
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
276
+
277
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
278
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
279
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
280
+
281
+ return prompt_embeds
282
+
283
+ def encode_prompt(
284
+ self,
285
+ prompt: Union[str, List[str]],
286
+ prompt_2: Union[str, List[str]],
287
+ device: Optional[torch.device] = None,
288
+ num_images_per_prompt: int = 1,
289
+ prompt_embeds: Optional[torch.FloatTensor] = None,
290
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
291
+ max_sequence_length: int = 512,
292
+ lora_scale: Optional[float] = None,
293
+ ):
294
+ r"""
295
+
296
+ Args:
297
+ prompt (`str` or `List[str]`, *optional*):
298
+ prompt to be encoded
299
+ prompt_2 (`str` or `List[str]`, *optional*):
300
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
301
+ used in all text-encoders
302
+ device: (`torch.device`):
303
+ torch device
304
+ num_images_per_prompt (`int`):
305
+ number of images that should be generated per prompt
306
+ prompt_embeds (`torch.FloatTensor`, *optional*):
307
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
308
+ provided, text embeddings will be generated from `prompt` input argument.
309
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
310
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
311
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
312
+ lora_scale (`float`, *optional*):
313
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
314
+ """
315
+ device = device or self._execution_device
316
+
317
+ # set lora scale so that monkey patched LoRA
318
+ # function of text encoder can correctly access it
319
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
320
+ self._lora_scale = lora_scale
321
+
322
+ # dynamically adjust the LoRA scale
323
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
324
+ scale_lora_layers(self.text_encoder, lora_scale)
325
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
326
+ scale_lora_layers(self.text_encoder_2, lora_scale)
327
+
328
+ prompt = [prompt] if isinstance(prompt, str) else prompt
329
+
330
+ if prompt_embeds is None:
331
+ prompt_2 = prompt_2 or prompt
332
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
333
+
334
+ # We only use the pooled prompt output from the CLIPTextModel
335
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
336
+ prompt=prompt,
337
+ device=device,
338
+ num_images_per_prompt=num_images_per_prompt,
339
+ )
340
+ prompt_embeds = self._get_t5_prompt_embeds(
341
+ prompt=prompt_2,
342
+ num_images_per_prompt=num_images_per_prompt,
343
+ max_sequence_length=max_sequence_length,
344
+ device=device,
345
+ )
346
+
347
+ if self.text_encoder is not None:
348
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
349
+ # Retrieve the original scale by scaling back the LoRA layers
350
+ unscale_lora_layers(self.text_encoder, lora_scale)
351
+
352
+ if self.text_encoder_2 is not None:
353
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
354
+ # Retrieve the original scale by scaling back the LoRA layers
355
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
356
+
357
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
358
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
359
+
360
+ return prompt_embeds, pooled_prompt_embeds, text_ids
361
+
362
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
363
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
364
+ if isinstance(generator, list):
365
+ image_latents = [
366
+ retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
367
+ for i in range(image.shape[0])
368
+ ]
369
+ image_latents = torch.cat(image_latents, dim=0)
370
+ else:
371
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
372
+
373
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
374
+
375
+ return image_latents
376
+
377
+ def check_inputs(
378
+ self,
379
+ prompt,
380
+ prompt_2,
381
+ height,
382
+ width,
383
+ prompt_embeds=None,
384
+ pooled_prompt_embeds=None,
385
+ callback_on_step_end_tensor_inputs=None,
386
+ max_sequence_length=None,
387
+ ):
388
+ if height % 8 != 0 or width % 8 != 0:
389
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
390
+
391
+ if callback_on_step_end_tensor_inputs is not None and not all(
392
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
393
+ ):
394
+ raise ValueError(
395
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
396
+ )
397
+
398
+ if prompt is not None and prompt_embeds is not None:
399
+ raise ValueError(
400
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
401
+ " only forward one of the two."
402
+ )
403
+ elif prompt_2 is not None and prompt_embeds is not None:
404
+ raise ValueError(
405
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
406
+ " only forward one of the two."
407
+ )
408
+ elif prompt is None and prompt_embeds is None:
409
+ raise ValueError(
410
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
411
+ )
412
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
413
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
414
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
415
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
416
+
417
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
418
+ raise ValueError(
419
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
420
+ )
421
+
422
+ if max_sequence_length is not None and max_sequence_length > 512:
423
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
424
+
425
+ @staticmethod
426
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
427
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
428
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
429
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
430
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
431
+ latent_image_ids = latent_image_ids.reshape(
432
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
433
+ )
434
+ return latent_image_ids.to(device=device, dtype=dtype)
435
+
436
+ @staticmethod
437
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
438
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
439
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
440
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
441
+ return latents
442
+
443
+ @staticmethod
444
+ def _unpack_latents(latents, height, width, vae_scale_factor):
445
+ batch_size, num_patches, channels = latents.shape
446
+
447
+ height = height // vae_scale_factor
448
+ width = width // vae_scale_factor
449
+
450
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
451
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
452
+
453
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
454
+
455
+ return latents
456
+
457
+ def enable_vae_slicing(self):
458
+ r"""
459
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
460
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
461
+ """
462
+ self.vae.enable_slicing()
463
+
464
+ def disable_vae_slicing(self):
465
+ r"""
466
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
467
+ computing decoding in one step.
468
+ """
469
+ self.vae.disable_slicing()
470
+
471
+ def enable_vae_tiling(self):
472
+ r"""
473
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
474
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
475
+ processing larger images.
476
+ """
477
+ self.vae.enable_tiling()
478
+
479
+ def disable_vae_tiling(self):
480
+ r"""
481
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
482
+ computing decoding in one step.
483
+ """
484
+ self.vae.disable_tiling()
485
+
486
+ def prepare_latents(
487
+ self,
488
+ batch_size,
489
+ num_channels_latents,
490
+ height,
491
+ width,
492
+ dtype,
493
+ device,
494
+ generator,
495
+ latents=None,
496
+ condition_image=None,
497
+ ):
498
+ height = 2 * (int(height) // self.vae_scale_factor)
499
+ width = 2 * (int(width) // self.vae_scale_factor)
500
+
501
+ shape = (batch_size, num_channels_latents, height, width) # 1 16 106 80
502
+
503
+ if latents is not None:
504
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
505
+ return latents.to(device=device, dtype=dtype), latent_image_ids
506
+
507
+ if isinstance(generator, list) and len(generator) != batch_size:
508
+ raise ValueError(
509
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
510
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
511
+ )
512
+ if condition_image is not None:
513
+ condition_image = condition_image.to(device=device, dtype=dtype)
514
+ image_latents = self._encode_vae_image(image=condition_image, generator=generator)
515
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
516
+ # expand init_latents for batch_size
517
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
518
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
519
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
520
+ raise ValueError(
521
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
522
+ )
523
+ else:
524
+ image_latents = torch.cat([image_latents], dim=0)
525
+
526
+ # import pdb; pdb.set_trace()
527
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
528
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
529
+ cond_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
530
+ latents = torch.concat([latents, cond_latents], dim=-2)
531
+
532
+ latent_image_ids = position_encoding_clone(batch_size, height, width, device, dtype) # add position
533
+
534
+ mask1 = torch.ones(shape, device=device, dtype=dtype)
535
+ mask2 = torch.zeros(shape, device=device, dtype=dtype)
536
+ mask1 = self._pack_latents(mask1, batch_size, num_channels_latents, height, width) # 1 4096 64
537
+ mask2 = self._pack_latents(mask2, batch_size, num_channels_latents, height, width) # 1 4096 64
538
+ mask = torch.concat([mask1, mask2], dim=-2)
539
+ return latents, latent_image_ids, mask, cond_latents
540
+
541
+ @property
542
+ def guidance_scale(self):
543
+ return self._guidance_scale
544
+
545
+ @property
546
+ def joint_attention_kwargs(self):
547
+ return self._joint_attention_kwargs
548
+
549
+ @property
550
+ def num_timesteps(self):
551
+ return self._num_timesteps
552
+
553
+ @property
554
+ def interrupt(self):
555
+ return self._interrupt
556
+
557
+ @torch.no_grad()
558
+ def __call__(
559
+ self,
560
+ prompt: Union[str, List[str]] = None,
561
+ prompt_2: Optional[Union[str, List[str]]] = None,
562
+ height: Optional[int] = None,
563
+ width: Optional[int] = None,
564
+ num_inference_steps: int = 28,
565
+ timesteps: List[int] = None,
566
+ guidance_scale: float = 3.5,
567
+ num_images_per_prompt: Optional[int] = 1,
568
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
569
+ latents: Optional[torch.FloatTensor] = None,
570
+ prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ output_type: Optional[str] = "pil",
573
+ return_dict: bool = True,
574
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
575
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
576
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
577
+ max_sequence_length: int = 512,
578
+ condition_image=None,
579
+ ):
580
+ height = height or self.default_sample_size * self.vae_scale_factor
581
+ width = width or self.default_sample_size * self.vae_scale_factor
582
+
583
+ # 1. Check inputs. Raise error if not correct
584
+ self.check_inputs(
585
+ prompt,
586
+ prompt_2,
587
+ height,
588
+ width,
589
+ prompt_embeds=prompt_embeds,
590
+ pooled_prompt_embeds=pooled_prompt_embeds,
591
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
592
+ max_sequence_length=max_sequence_length,
593
+ )
594
+
595
+ self._guidance_scale = guidance_scale
596
+ self._joint_attention_kwargs = joint_attention_kwargs
597
+ self._interrupt = False
598
+
599
+ condition_image = self.image_processor.preprocess(condition_image, height=height, width=width)
600
+ condition_image = condition_image.to(dtype=torch.float32)
601
+
602
+ # 2. Define call parameters
603
+ if prompt is not None and isinstance(prompt, str):
604
+ batch_size = 1
605
+ elif prompt is not None and isinstance(prompt, list):
606
+ batch_size = len(prompt)
607
+ else:
608
+ batch_size = prompt_embeds.shape[0]
609
+
610
+ device = self._execution_device
611
+
612
+ lora_scale = (
613
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
614
+ )
615
+ (
616
+ prompt_embeds,
617
+ pooled_prompt_embeds,
618
+ text_ids,
619
+ ) = self.encode_prompt(
620
+ prompt=prompt,
621
+ prompt_2=prompt_2,
622
+ prompt_embeds=prompt_embeds,
623
+ pooled_prompt_embeds=pooled_prompt_embeds,
624
+ device=device,
625
+ num_images_per_prompt=num_images_per_prompt,
626
+ max_sequence_length=max_sequence_length,
627
+ lora_scale=lora_scale,
628
+ )
629
+
630
+ # 4. Prepare latent variables
631
+ num_channels_latents = self.transformer.config.in_channels // 4 # 16
632
+ latents, latent_image_ids, mask, cond_latents = self.prepare_latents(
633
+ batch_size * num_images_per_prompt,
634
+ num_channels_latents,
635
+ height,
636
+ width,
637
+ prompt_embeds.dtype,
638
+ device,
639
+ generator,
640
+ latents,
641
+ condition_image
642
+ )
643
+ clean_latents = latents.clone()
644
+
645
+ # 5. Prepare timesteps
646
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
647
+ image_seq_len = latents.shape[1]
648
+ mu = calculate_shift(
649
+ image_seq_len,
650
+ self.scheduler.config.base_image_seq_len,
651
+ self.scheduler.config.max_image_seq_len,
652
+ self.scheduler.config.base_shift,
653
+ self.scheduler.config.max_shift,
654
+ )
655
+ timesteps, num_inference_steps = retrieve_timesteps(
656
+ self.scheduler,
657
+ num_inference_steps,
658
+ device,
659
+ timesteps,
660
+ sigmas,
661
+ mu=mu,
662
+ )
663
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
664
+ self._num_timesteps = len(timesteps)
665
+
666
+ # handle guidance
667
+ if self.transformer.config.guidance_embeds:
668
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
669
+ guidance = guidance.expand(latents.shape[0])
670
+ else:
671
+ guidance = None
672
+
673
+ # 6. Denoising loop
674
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
675
+ for i, t in enumerate(timesteps):
676
+ if self.interrupt:
677
+ continue
678
+
679
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
680
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
681
+ noise_pred = self.transformer(
682
+ hidden_states=latents, # 1 4096 64
683
+ timestep=timestep / 1000,
684
+ guidance=guidance,
685
+ pooled_projections=pooled_prompt_embeds,
686
+ encoder_hidden_states=prompt_embeds,
687
+ txt_ids=text_ids,
688
+ img_ids=latent_image_ids,
689
+ joint_attention_kwargs=self.joint_attention_kwargs,
690
+ return_dict=False,
691
+ )[0]
692
+
693
+ # compute the previous noisy sample x_t -> x_t-1
694
+ latents_dtype = latents.dtype
695
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
696
+ latents = latents * mask + clean_latents * (1 - mask)
697
+
698
+ if latents.dtype != latents_dtype:
699
+ if torch.backends.mps.is_available():
700
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
701
+ latents = latents.to(latents_dtype)
702
+
703
+ if callback_on_step_end is not None:
704
+ callback_kwargs = {}
705
+ for k in callback_on_step_end_tensor_inputs:
706
+ callback_kwargs[k] = locals()[k]
707
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
708
+
709
+ latents = callback_outputs.pop("latents", latents)
710
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
711
+
712
+ # call the callback, if provided
713
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
714
+ progress_bar.update()
715
+
716
+ if XLA_AVAILABLE:
717
+ xm.mark_step()
718
+
719
+ if output_type == "latent":
720
+ image = latents
721
+
722
+ else:
723
+ latents = self._unpack_latents(latents[:,:latents.shape[-2]-cond_latents.shape[-2],:], height, width, self.vae_scale_factor)
724
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
725
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
726
+ image = self.image_processor.postprocess(image, output_type=output_type)
727
+
728
+ # Offload all models
729
+ self.maybe_free_model_hooks()
730
+
731
+ if not return_dict:
732
+ return (image,)
733
+
734
+ return FluxPipelineOutput(images=image)
src/prompt_helper.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def load_text_encoders(args, class_one, class_two):
5
+ text_encoder_one = class_one.from_pretrained(
6
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
7
+ )
8
+ text_encoder_two = class_two.from_pretrained(
9
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
10
+ )
11
+ return text_encoder_one, text_encoder_two
12
+
13
+
14
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
15
+ text_inputs = tokenizer(
16
+ prompt,
17
+ padding="max_length",
18
+ max_length=max_sequence_length,
19
+ truncation=True,
20
+ return_length=False,
21
+ return_overflowing_tokens=False,
22
+ return_tensors="pt",
23
+ )
24
+ text_input_ids = text_inputs.input_ids
25
+ return text_input_ids
26
+
27
+
28
+ def tokenize_prompt_clip(tokenizer, prompt):
29
+ text_inputs = tokenizer(
30
+ prompt,
31
+ padding="max_length",
32
+ max_length=77,
33
+ truncation=True,
34
+ return_length=False,
35
+ return_overflowing_tokens=False,
36
+ return_tensors="pt",
37
+ )
38
+ text_input_ids = text_inputs.input_ids
39
+ return text_input_ids
40
+
41
+
42
+ def tokenize_prompt_t5(tokenizer, prompt):
43
+ text_inputs = tokenizer(
44
+ prompt,
45
+ padding="max_length",
46
+ max_length=512,
47
+ truncation=True,
48
+ return_length=False,
49
+ return_overflowing_tokens=False,
50
+ return_tensors="pt",
51
+ )
52
+ text_input_ids = text_inputs.input_ids
53
+ return text_input_ids
54
+
55
+
56
+ def _encode_prompt_with_t5(
57
+ text_encoder,
58
+ tokenizer,
59
+ max_sequence_length=512,
60
+ prompt=None,
61
+ num_images_per_prompt=1,
62
+ device=None,
63
+ text_input_ids=None,
64
+ ):
65
+ prompt = [prompt] if isinstance(prompt, str) else prompt
66
+ batch_size = len(prompt)
67
+
68
+ if tokenizer is not None:
69
+ text_inputs = tokenizer(
70
+ prompt,
71
+ padding="max_length",
72
+ max_length=max_sequence_length,
73
+ truncation=True,
74
+ return_length=False,
75
+ return_overflowing_tokens=False,
76
+ return_tensors="pt",
77
+ )
78
+ text_input_ids = text_inputs.input_ids
79
+ else:
80
+ if text_input_ids is None:
81
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
82
+
83
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
84
+
85
+ dtype = text_encoder.dtype
86
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87
+
88
+ _, seq_len, _ = prompt_embeds.shape
89
+
90
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93
+
94
+ return prompt_embeds
95
+
96
+
97
+ def _encode_prompt_with_clip(
98
+ text_encoder,
99
+ tokenizer,
100
+ prompt: str,
101
+ device=None,
102
+ text_input_ids=None,
103
+ num_images_per_prompt: int = 1,
104
+ ):
105
+ prompt = [prompt] if isinstance(prompt, str) else prompt
106
+ batch_size = len(prompt)
107
+
108
+ if tokenizer is not None:
109
+ text_inputs = tokenizer(
110
+ prompt,
111
+ padding="max_length",
112
+ max_length=77,
113
+ truncation=True,
114
+ return_overflowing_tokens=False,
115
+ return_length=False,
116
+ return_tensors="pt",
117
+ )
118
+
119
+ text_input_ids = text_inputs.input_ids
120
+ else:
121
+ if text_input_ids is None:
122
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
123
+
124
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
125
+
126
+ # Use pooled output of CLIPTextModel
127
+ prompt_embeds = prompt_embeds.pooler_output
128
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
129
+
130
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
131
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
132
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
133
+
134
+ return prompt_embeds
135
+
136
+
137
+ def encode_prompt(
138
+ text_encoders,
139
+ tokenizers,
140
+ prompt: str,
141
+ max_sequence_length,
142
+ device=None,
143
+ num_images_per_prompt: int = 1,
144
+ text_input_ids_list=None,
145
+ ):
146
+ prompt = [prompt] if isinstance(prompt, str) else prompt
147
+ dtype = text_encoders[0].dtype
148
+
149
+ pooled_prompt_embeds = _encode_prompt_with_clip(
150
+ text_encoder=text_encoders[0],
151
+ tokenizer=tokenizers[0],
152
+ prompt=prompt,
153
+ device=device if device is not None else text_encoders[0].device,
154
+ num_images_per_prompt=num_images_per_prompt,
155
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
156
+ )
157
+
158
+ prompt_embeds = _encode_prompt_with_t5(
159
+ text_encoder=text_encoders[1],
160
+ tokenizer=tokenizers[1],
161
+ max_sequence_length=max_sequence_length,
162
+ prompt=prompt,
163
+ num_images_per_prompt=num_images_per_prompt,
164
+ device=device if device is not None else text_encoders[1].device,
165
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
166
+ )
167
+
168
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
169
+
170
+ return prompt_embeds, pooled_prompt_embeds, text_ids
171
+
172
+
173
+ def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
174
+ text_encoder_clip = text_encoders[0]
175
+ text_encoder_t5 = text_encoders[1]
176
+ tokens_clip, tokens_t5 = tokens[0], tokens[1]
177
+ batch_size = tokens_clip.shape[0]
178
+
179
+ if device == "cpu":
180
+ device = "cpu"
181
+ else:
182
+ device = accelerator.device
183
+
184
+ # clip
185
+ prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
186
+ # Use pooled output of CLIPTextModelpreprocess_train
187
+ prompt_embeds = prompt_embeds.pooler_output
188
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
189
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
190
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
+ pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
192
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
193
+
194
+ # t5
195
+ prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
196
+ dtype = text_encoder_t5.dtype
197
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
198
+ _, seq_len, _ = prompt_embeds.shape
199
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
200
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
201
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
202
+
203
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
204
+
205
+ return prompt_embeds, pooled_prompt_embeds, text_ids