File size: 3,364 Bytes
323deb3
0a82683
 
 
 
762146c
0b2e184
76b4c2e
1ec1764
 
e02c582
 
0a82683
 
 
c07cc81
 
 
e02c582
27d1cb4
6c59011
e02c582
 
0a82683
c07cc81
0a82683
5334972
47f6862
0a82683
6c59011
 
 
 
0a82683
 
 
 
 
 
 
 
 
 
 
 
 
6c59011
5bf7bea
0a82683
6c59011
0a82683
13dc11f
 
 
 
0b82d35
 
 
 
504a8f8
 
2650b50
 
0b82d35
 
13dc11f
 
1ec1764
0a82683
 
 
1ec1764
6c59011
 
0a82683
1ec1764
 
0a82683
 
 
1e7b1a9
0a82683
 
 
 
917fecc
13dc11f
0a82683
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import spaces
import gradio as gr
import torch
from PIL import Image
from src.pipeline_pe_clone import FluxPipeline
import os
import huggingface_hub
# Authenticate with the Hugging Face Hub so the model weights below can be
# downloaded (presumably required because the FLUX.1-dev repo is gated).
# Only log in when a token is actually configured: login(None) falls back to an
# interactive prompt, which cannot be answered on a headless server. Without a
# token, huggingface_hub can still pick up a standard HF_TOKEN secret on its own.
_hf_token = os.getenv('HF_TOKEN_FLUX2')
if _hf_token:
    huggingface_hub.login(_hf_token)

# Example image pre-filled in the input widget of the interface.
default_image = Image.open("assets/1.png")

# Load FLUX.1-dev once at startup in bfloat16 and move it to the GPU.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to('cuda')

# Bake the base PhotoDoodle LoRA into the model weights, then drop the LoRA
# layers so per-request effect LoRAs can be loaded and unloaded on top of it
# without removing the fused base.
pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name="pretrain.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()

@spaces.GPU
def generate_image(image, prompt, guidance_scale, num_steps, lora_name):
    """Apply a PhotoDoodle edit to ``image`` and return the edited image.

    The base PhotoDoodle LoRA is fused into the module-level ``pipeline`` at
    startup. If ``lora_name`` is not ``'pretrained'``, the matching
    ``{lora_name}.safetensors`` effect LoRA is loaded from the
    ``nicolaus-huang/PhotoDoodle`` repo for this call only. It is unloaded
    afterwards, so effects from earlier requests do not stack up on the shared
    pipeline.

    Args:
        image: Input PIL image (the interface declares ``gr.Image(type="pil")``).
        prompt: Text instruction describing the edit.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_steps: Number of inference steps. It is coerced to ``int`` because
            slider values may arrive as floats.
        lora_name: Effect LoRA to apply, or ``'pretrained'`` for the base
            model only.

    Returns:
        The generated PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Fixed generation resolution: 512 wide x 768 tall.
    height = 768
    width = 512

    # PIL's resize takes (width, height); pass width first so the condition
    # image matches the size the pipeline is asked to generate.
    condition_image = image.resize((width, height)).convert("RGB")

    use_effect_lora = lora_name != 'pretrained'
    try:
        if use_effect_lora:
            pipeline.load_lora_weights(
                "nicolaus-huang/PhotoDoodle",
                weight_name=f"{lora_name}.safetensors",
            )

        result = pipeline(
            prompt=prompt,
            condition_image=condition_image,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_steps),
            max_sequence_length=512,
        ).images[0]
    finally:
        if use_effect_lora:
            # Drop the un-fused effect adapter so it does not leak into later
            # requests; unload_lora_weights does not undo the base fuse_lora.
            pipeline.unload_lora_weights()

    # Return the result at the caller's original resolution.
    return result.resize(image.size)

# Example rows for the Gradio interface. Every example uses the same guidance
# scale and step count, so only (image path, prompt, LoRA name) varies.
# NOTE(review): the "sksmonstercalledlulu" prompt ("add a huge by ...") looks
# truncated — confirm the intended wording.
_EXAMPLE_SPECS = (
    ("assets/1.png", "add a halo and wings for the cat by sksmagiceffects", "sksmagiceffects"),
    ("assets/1.png", "add a huge by sksmonstercalledlulu", "sksmonstercalledlulu"),
    ("assets/1.png", "Add colorful magical effects and flowing color blocks to the cat  by skspaintingeffects", "skspaintingeffects"),
    ("assets/1.png", "Add hand-drawn lines and star decorations.", "sksedgeeffect"),
    ("assets/hmgoepprod (1).jpeg", "Add pink star effect.", "sksmagiceffects"),
    ("assets/hmgoepprod (3).jpeg", "Add pink star effect.", "sksmagiceffects"),
)

# Expand each spec into the interface's input order:
# [Input Image, Prompt, Guidance Scale, Number of Steps, LoRA Name]
examples = [
    [img_path, text, 3.5, 20, effect]
    for img_path, text, effect in _EXAMPLE_SPECS
]

# LoRA effects offered in the dropdown. "pretrained" applies no extra effect
# LoRA on top of the base model.
_LORA_CHOICES = [
    "pretrained",
    "sksmagiceffects",
    "sksmonstercalledlulu",
    "skspaintingeffects",
    "sksedgeeffect",
    "skscatooneffect",
    'skscloudsketch',
]

# Input widgets, listed in the same order as generate_image's parameters:
# (image, prompt, guidance_scale, num_steps, lora_name).
_inputs = [
    gr.Image(label="Input Image", type="pil", value=default_image),
    gr.Textbox(label="Prompt", value="add a halo and wings for the cat by sksmagiceffects"),
    gr.Slider(label="Guidance Scale", value=3.5, minimum=1.0, maximum=10.0, step=0.1),
    gr.Slider(label="Number of Steps", value=20, minimum=1, maximum=100, step=1),
    gr.Dropdown(label="LoRA Name", choices=_LORA_CHOICES, value="sksmagiceffects"),
]

# Single-function Gradio UI wrapping generate_image.
iface = gr.Interface(
    fn=generate_image,
    inputs=_inputs,
    outputs=gr.Image(label="Output Image", type="pil"),
    title="PhotoDoodle-Image-Edit with LoRA",
    examples=examples,
)

if __name__ == "__main__":
    iface.launch()