File size: 5,629 Bytes
3175ce6
 
 
9235b7f
47500db
9235b7f
371bdca
1893705
 
 
47500db
1893705
 
47500db
50bfc5a
 
 
 
a7bee92
 
 
 
 
 
50bfc5a
c97fcf1
 
 
a7bee92
371bdca
50bfc5a
 
9235b7f
a7bee92
9235b7f
 
 
d033e91
ab0b470
6ef3309
50bfc5a
 
 
 
6ef3309
b3a0761
9235b7f
 
e6730cb
 
9235b7f
fc8037f
3175ce6
fc8037f
9235b7f
47500db
9235b7f
3175ce6
 
 
 
7ceb780
3175ce6
3d23955
7ceb780
3d23955
 
 
3175ce6
7ceb780
3d23955
 
3175ce6
3d23955
 
3175ce6
7ceb780
3d23955
3175ce6
7ceb780
3d23955
7ceb780
3d23955
 
 
 
 
7ceb780
 
 
 
 
 
3175ce6
3d23955
 
 
 
3175ce6
47500db
9235b7f
114a69f
 
9235b7f
114a69f
 
 
9235b7f
9cda2f8
 
9235b7f
db52967
 
9cda2f8
47500db
9235b7f
 
 
 
 
1893705
 
 
 
3175ce6
9235b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import cv2
import numpy as np
from PIL import Image
import os
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
import shlex
import subprocess

# Startup housekeeping, executed at import time. None of these calls pass
# check=True, so a failing command does not stop the script.

# Remove everything under /data-nvme/zerogpu-offload. The glob needs a shell,
# and the command runs with an empty environment.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", shell=True, env={})

# Install flash-attn with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE added to the
# current environment (presumably to avoid compiling CUDA kernels here --
# confirm against the flash-attn build documentation).
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

# Install scepter without pulling in its declared dependencies.
subprocess.run(["pip", "install", "scepter", "--no-deps"])

def resolve_hf_path(path):
    """Turn an ``hf://`` URI into a local path; return anything else unchanged.

    Accepted forms are ``hf://<repo_id>`` (the whole repository snapshot is
    downloaded) and ``hf://<repo_id>@<filename>`` (only that file is
    downloaded). An empty filename after ``@`` is treated like the
    repository-only form.

    Args:
        path: Value to resolve. Non-strings and strings without the ``hf://``
            prefix are returned as-is.

    Returns:
        The local path reported by ``hf_hub_download`` / ``snapshot_download``
        for ``hf://`` URIs, otherwise ``path`` itself.

    Raises:
        ValueError: If the URI contains more than one ``@``, or if the
            ``HUGGINGFACE_HUB_TOKEN`` environment variable is not set.
    """
    scheme = "hf://"
    if not (isinstance(path, str) and path.startswith(scheme)):
        return path

    pieces = path[len(scheme):].split("@")
    if len(pieces) > 2:
        raise ValueError(f"Invalid HF URI format: {path}")
    repo_id = pieces[0]
    filename = pieces[1] if len(pieces) == 2 else None

    # The token is mandatory for every hf:// lookup, and is only checked
    # after the URI shape has been validated.
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token is None:
        raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable not set!")

    if filename:
        return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    return snapshot_download(repo_id=repo_id, token=token)

# Publish the model locations as environment variables (presumably read by
# the configuration loaded further down -- confirm against the YAML config).
os.environ.update({
    "FLUX_FILL_PATH": "hf://black-forest-labs/FLUX.1-Fill-dev",
    "PORTRAIT_MODEL_PATH": "ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors",
    "SUBJECT_MODEL_PATH": "ms://iic/ACE_Plus@subject/comfyui_subject_lora16.safetensors",
    "LOCAL_MODEL_PATH": "ms://iic/ACE_Plus@local_editing/comfyui_local_lora16.safetensors",
    "ACE_PLUS_FFT_MODEL": "hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors",
})

# Only the two hf:// entries are resolved to local paths here; the ms://
# entries are left untouched.
flux_full = resolve_hf_path(os.environ["FLUX_FILL_PATH"])
ace_plus_fft_model_path = resolve_hf_path(os.environ["ACE_PLUS_FFT_MODEL"])

# Overwrite the hf:// URIs with the resolved local paths.
os.environ.update({
    "ACE_PLUS_FFT_MODEL": ace_plus_fft_model_path,
    "FLUX_FILL_PATH": flux_full,
})

from inference.ace_plus_inference import ACEInference
from scepter.modules.utils.config import Config
from modules.flux import FluxMRModiACEPlus
from inference.registry import INFERENCES


# Load config/ace_plus_fft.yaml and let the INFERENCES registry build the
# pipeline object from it (presumably an ACEInference instance -- confirm
# against the config's NAME entry).
config_path = os.path.join("config", "ace_plus_fft.yaml")
cfg = Config(cfg_file=config_path, load=True)
ace_infer = INFERENCES.build(cfg)

def create_face_mask(pil_image):
    """
    Create a binary face mask (PIL Image) for a PIL image.

    Faces are located with OpenCV's frontal-face Haar cascade. Each detection
    box is grown by 20% of its width/height on every side, clamped to the
    image bounds, and filled as a white (255) ellipse on a black (0)
    background with the same height and width as the input. An ellipse is
    used to better match the shape of a face. If no face is detected the
    returned mask is entirely black.

    Args:
        pil_image: Image to analyse; it must provide ``convert("RGB")``.

    Returns:
        A PIL Image built from the single-channel uint8 mask.

    Raises:
        ValueError: If any step fails. The message includes the underlying
            error, and the original exception is chained as ``__cause__``.
    """
    try:
        # Convert the PIL image to a numpy array in RGB format, then to
        # grayscale for face detection.
        image_np = np.array(pil_image.convert("RGB"))
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        img_h, img_w = gray.shape[:2]

        # Load the Haar cascade shipped with OpenCV and detect faces.
        cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        face_cascade = cv2.CascadeClassifier(cascade_path)
        faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

        # Start from an all-black mask with the same dimensions as the image.
        mask = np.zeros_like(gray, dtype=np.uint8)

        # Fraction of the box size added on each side for a better fit.
        padding = 0.2
        for (x, y, w, h) in faces:
            # Expanded bounding box, clamped to the image.
            x1 = max(0, int(x - w * padding))
            y1 = max(0, int(y - h * padding))
            x2 = min(img_w, int(x + w * (1 + padding)))
            y2 = min(img_h, int(y + h * (1 + padding)))

            # Ellipse inscribed in the expanded box.
            center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
            axes = (int((x2 - x1) / 2), int((y2 - y1) / 2))
            # Draw a filled ellipse (white) on the mask.
            cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)

        return Image.fromarray(mask)
    except Exception as e:
        print(f"Error: {e}")
        # Keep ValueError for callers, but report what actually went wrong
        # and preserve the original traceback via exception chaining.
        raise ValueError(f"Failed to create face mask: {e}") from e

@spaces.GPU(duration=80)
def face_swap_app(target_img, face_img):
    """Run the ACE++ pipeline on a target/face image pair and return the result.

    Both inputs are converted to RGB. A face mask is computed from
    ``face_img``; ``face_img`` and that mask are passed to the pipeline as
    ``edit_image`` / ``edit_mask``, while ``target_img`` is passed as
    ``reference_image``.
    NOTE(review): confirm this role assignment matches the intended swap
    direction.

    Args:
        target_img: Image object providing ``convert("RGB")`` (the UI supplies
            PIL images), or None.
        face_img: Image object providing ``convert("RGB")``, or None.

    Returns:
        The first element of the five-item result returned by ``ace_infer``.

    Raises:
        ValueError: If either image is None, or if mask creation fails.
    """
    if face_img is None or target_img is None:
        raise ValueError("Both a target image and a face image must be provided.")

    # Normalise both inputs to RGB before use.
    reference_rgb = target_img.convert("RGB")
    face_rgb = face_img.convert("RGB")

    face_mask = create_face_mask(face_rgb)

    # Fixed generation settings for the demo; seed=-1 presumably requests a
    # random seed -- confirm against the inference implementation.
    generation_settings = dict(
        reference_image=reference_rgb,
        edit_image=face_rgb,
        edit_mask=face_mask,
        prompt="maintain the facial features as much as possible",
        output_height=1024,
        output_width=1024,
        sampler='flow_euler',
        sample_steps=28,
        guide_scale=50,
        repainting_scale=1.0,
        use_change=True,
        keep_pixels=True,
        keep_pixels_rate=0.8,
        seed=-1,
    )

    # Only the generated image is used; the other four outputs are discarded.
    swapped_img, _edit_image, _change_image, _mask, _seed = ace_infer(**generation_settings)
    return swapped_img

# Gradio UI: two PIL image inputs are passed to face_swap_app, whose return
# value is shown as a single PIL image output.
iface = gr.Interface(
    title="ACE++ Face Swap Demo",
    description="Upload a target image and a face image to swap the face using the ACE++ model.",
    fn=face_swap_app,
    inputs=[
        gr.Image(label="Target Image", type="pil"),
        gr.Image(label="Face Image", type="pil"),
    ],
    outputs=gr.Image(label="Swapped Face Output", type="pil"),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()