FLUX.1-merged

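This repository provides transformer parameters obtained by merging black-forest-labs/FLUX.1-dev and black-forest-labs/FLUX.1-schnell: the two models' weights are averaged, and the guidance-embedding weights are copied from FLUX.1-dev. Please check the licenses of both models before using these parameters commercially.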

| Dev (50 steps) | Dev (4 steps) | Dev + Schnell (4 steps) |
| --- | --- | --- |
| Dev 50 Steps | Dev 4 Steps | Dev + Schnell 4 Steps |

Merging code (loads the checkpoints shard by shard to reduce peak memory)

from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch


# Build the FLUX transformer skeleton without allocating real weights; the
# merged tensors are materialized into it by load_model_dict_into_meta below.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Fetch only the transformer shards of each checkpoint.
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

if not dev_shards:
    raise FileNotFoundError(f"No .safetensors shards found under {dev_ckpt}/transformer.")
# Shards are paired by sorted position, so both checkpoints must be split into
# the same number of files; otherwise indexing/zip would misbehave silently.
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Expected matching shard counts but got {len(dev_shards)} (dev) and {len(schnell_shards)} (schnell)."
    )

merged_state_dict = {}
guidance_state_dict = {}

# Process one pair of shards at a time so only two shards are resident
# (alongside the growing merged dict) at any moment.
for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    for k in list(state_dict_dev_temp.keys()):
        dev_param = state_dict_dev_temp.pop(k)
        if "guidance" in k:
            # Guidance weights are taken from dev unchanged (not averaged).
            guidance_state_dict[k] = dev_param
            continue
        if k not in state_dict_schnell_temp:
            raise KeyError(f"{k!r} found in dev shard {dev_shard} but not in schnell shard {schnell_shard}.")
        schnell_param = state_dict_schnell_temp.pop(k)
        # Average in float32 so the sum is not rounded to the storage dtype
        # before halving, then cast back to the checkpoint's original dtype.
        merged_state_dict[k] = ((dev_param.float() + schnell_param.float()) / 2).to(dev_param.dtype)

    # Every dev key was popped above; anything left in the schnell shard has
    # no dev counterpart in the paired shard.
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("merged-flux")

Inference code

from diffusers import FluxPipeline
import torch

# Load the merged checkpoint in bfloat16 and move it to the GPU.
pipeline = FluxPipeline.from_pretrained(
    "sayakpaul/FLUX.1-merged",
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to("cuda")

# Few-step sampling settings; the generator is seeded only after the
# pipeline has been loaded, matching the original call order.
generation_kwargs = dict(
    prompt="a tiny astronaut hatching from an egg on the moon",
    guidance_scale=3.5,
    num_inference_steps=4,
    height=880,
    width=1184,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
)
result = pipeline(**generation_kwargs)

# Keep the first generated image and write it to disk.
image = result.images[0]
image.save("merged_flux.png")

Documentation

Downloads last month
943
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for sayakpaul/FLUX.1-merged

Spaces using sayakpaul/FLUX.1-merged 23