|
--- |
|
base_model: |
|
- black-forest-labs/FLUX.1-dev |
|
- black-forest-labs/FLUX.1-schnell |
|
language: |
|
- en |
|
license: other |
|
license_name: flux-1-dev-non-commercial-license |
|
license_link: LICENSE.md |
|
tags: |
|
- merge |
|
- flux |
|
--- |
|
|
|
# Aryanne/flux_swap |
|
This model is a merge of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell). |
|
|
|
Unlike other merge methods, the tensor values here are not modified. Instead, they are swapped in a checkerboard pattern with the corresponding values from FLUX.1-schnell, so roughly 50% of each model's weights are present (assuming my code is correct).
|
|
|
```python |
|
from diffusers import FluxTransformer2DModel |
|
from huggingface_hub import snapshot_download |
|
from accelerate import init_empty_weights |
|
from diffusers.models.model_loading_utils import load_model_dict_into_meta |
|
import safetensors.torch |
|
import glob |
|
import torch |
|
import gc |
|
|
|
|
|
|
|
|
|
# Instantiate the transformer architecture without allocating real weights;
# the merged state dict is loaded into this skeleton further below.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer"
    )
    model = FluxTransformer2DModel.from_config(config)

# Fetch only the transformer sub-folder of each repository.
dev_ckpt = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*"
)
schnell_ckpt = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*"
)

# Sorted so the i-th dev shard can be paired with the i-th schnell shard
# (this assumes both repositories shard their weights the same way).
dev_shards, schnell_shards = (
    sorted(glob.glob(f"{ckpt}/transformer/*.safetensors"))
    for ckpt in (dev_ckpt, schnell_ckpt)
)
|
|
|
def _checkerboard_mask(shape, n, device):
    """Return a bool mask that is True where the sum of a position's indices is divisible by n.

    For a 1-D shape this is ``index % n == 0``; for a 2-D shape it is
    ``(row + col) % n == 0`` (a checkerboard when n == 2). Higher-rank shapes
    use the sum of all indices, and a 0-d shape yields a single True.
    """
    index_sum = torch.zeros((), dtype=torch.long, device=device)
    for axis, size in enumerate(shape):
        view_shape = [1] * len(shape)
        view_shape[axis] = size
        # Broadcasting the per-axis ranges together builds the full index-sum grid.
        index_sum = index_sum + torch.arange(size, device=device).view(view_shape)
    return index_sum % n == 0


def _random_mask(shape, percent, seed, device):
    """Return a bool mask whose elements are True with probability ``percent``.

    A private ``torch.Generator`` is used so the global RNG state is left
    untouched. With ``seed=None`` the generator is seeded non-deterministically.
    """
    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return (torch.rand(shape, generator=generator) <= percent).to(device)


def swapping_method(base, x, parameters):
    """Merge two same-shaped tensors by swapping values instead of blending them.

    Each output element is copied verbatim from either ``base`` or ``x``.
    In this script ``base`` is the FLUX.1-dev tensor and ``x`` the FLUX.1-schnell one.

    Args:
        base: tensor whose values are kept outside the mask.
        x: tensor whose values are placed where the mask is True.
        parameters: dict with keys
            ``diagonal_offset`` (int >= 2, required): period n of the index-sum
                mask ``sum(indices) % n == 0``; n == 2 gives a checkerboard.
            ``random_mask`` (float in (0, 1), or falsy to disable): if set, a
                random mask with that fraction of True values is used instead.
            ``random_mask_seed`` (int or int-like string, optional): seed for
                the random mask; None means non-deterministic.
            ``invert_offset`` (bool, default False): if True, masked positions
                keep ``base`` and the others take ``x`` (index-sum mask only).

    Returns:
        The merged tensor (cast to bfloat16 when the inputs are on CPU).

    Raises:
        ValueError: on an invalid ``diagonal_offset``, ``random_mask`` or seed.
    """
    if x.device.type == "cpu":
        x = x.to(torch.bfloat16)
        base = base.to(torch.bfloat16)

    diagonal_offset = parameters.get('diagonal_offset')
    random_mask = parameters.get('random_mask')
    random_mask_seed = parameters.get('random_mask_seed')
    if random_mask_seed is not None:
        # Seeds may be given as strings (e.g. "899557"); int() raises ValueError otherwise.
        random_mask_seed = int(random_mask_seed)

    # Explicit raises instead of asserts, so validation survives `python -O`.
    if diagonal_offset is None or diagonal_offset % 1 != 0 or diagonal_offset < 2:
        raise ValueError("The diagonal_offset must be an integer greater than or equal to 2.")

    if random_mask:
        if not 0.0 < random_mask < 1.0:
            raise ValueError("The random_mask parameter can't be empty, 0, 1, or None, it must be a number between 0 and 1.")
        mask = _random_mask(base.shape, random_mask, random_mask_seed, base.device)
        return torch.where(mask, x, base)

    mask = _checkerboard_mask(x.shape, int(diagonal_offset), base.device)
    # A missing 'invert_offset' key means "not inverted".
    if parameters.get('invert_offset', False):
        return torch.where(mask, base, x)
    return torch.where(mask, x, base)
|
|
|
# Swap settings passed to swapping_method():
#   diagonal_offset -- period n of the mask (sum of indices) % n == 0; n=2 gives a checkerboard.
#   random_mask     -- False selects the deterministic pattern; a float in (0, 1)
#                      instead swaps in a random fraction of FLUX.1-schnell values.
#   invert_offset   -- True flips which model fills the masked positions.
# Optionally add 'random_mask_seed': "899557" to make the random mask reproducible.
parameters = dict(
    diagonal_offset=2,
    random_mask=False,
    invert_offset=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
merged_state_dict = {}
guidance_state_dict = {}

# Shards are merged pairwise, so both checkpoints must be split the same way.
if not dev_shards:
    raise ValueError("No FLUX.1-dev transformer shards were found.")
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} FLUX.1-dev shards vs "
        f"{len(schnell_shards)} FLUX.1-schnell shards."
    )

for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    # Assumes every non-guidance key of this dev shard is also in the paired schnell shard.
    for k in list(state_dict_dev_temp):
        if "guidance" in k:
            # Guidance weights are copied from FLUX.1-dev unchanged.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)
        else:
            merged_state_dict[k] = swapping_method(
                state_dict_dev_temp.pop(k), state_dict_schnell_temp.pop(k), parameters
            )

    # Every tensor of both shards should have been consumed above.
    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

    # Release this shard pair before loading the next one to limit peak memory.
    del state_dict_dev_temp, state_dict_schnell_temp
    gc.collect()

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("merged-flux")
|
``` |
|
|
|
Part of this code is adapted from [mergekit](https://github.com/Ar57m/mergekit/tree/swapping).
|
|
|
Thanks to SayakPaul, whose code helped me make this merge.