Learner commited on
Commit
79ded7b
·
1 Parent(s): 0174344

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import requests
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from io import BytesIO
7
+
8
+ # Diffusers
9
+ from diffusers import StableDiffusionImg2ImgPipeline
10
+ from diffusers import (
11
+ FlaxStableDiffusionControlNetPipeline,
12
+ FlaxControlNetModel,
13
+ FlaxStableDiffusionPipeline,
14
+ )
15
+ from diffusers import ControlNetModel
16
+ from diffusers.utils import load_image
17
+
18
+ # Pytorch
19
+ import torch
20
+
21
+ # Numpy
22
+ import numpy as np
23
+
24
+ # Jax
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from jax import pmap
28
+
29
+ # Flax
30
+ import flax
31
+ from flax.jax_utils import replicate
32
+ from flax.training.common_utils import shard
33
+
34
+
35
+ def create_key(seed=0):
36
+ return jax.random.PRNGKey(seed)
37
+
38
+
39
+ def image_grid(imgs, rows, cols):
40
+ w, h = imgs[0].size
41
+ grid = Image.new("RGB", size=(cols * w, rows * h))
42
+ for i, img in enumerate(imgs):
43
+ grid.paste(img, box=(i % cols * w, i // cols * h))
44
+ return grid
45
+
46
+
47
+ # load control net and stable diffusion v1-5
48
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
49
+ "jax-diffuser-event/learner/trained_model_v0.1", from_flax=True, dtype=jnp.float32
50
+ )
51
+
52
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
53
+ "runwayml/stable-diffusion-v1-5",
54
+ controlnet=controlnet,
55
+ from_pt=True,
56
+ dtype=jnp.float32,
57
+ safety_checker=None,
58
+ )
59
+
60
+
61
+ # inference function takes prompt, negative prompt and image
62
+ def infer(prompts, negative_prompts, image):
63
+ params["controlnet"] = controlnet_params
64
+
65
+ num_samples = 1 # jax.device_count()
66
+ rng = create_key(0)
67
+ rng = jax.random.split(rng, jax.device_count())
68
+ battlemap_image = load_image(image)
69
+
70
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
71
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
72
+ processed_image = pipe.prepare_image_inputs([battlemap_image] * num_samples)
73
+
74
+ p_params = replicate(params)
75
+ prompt_ids = shard(prompt_ids)
76
+ negative_prompt_ids = shard(negative_prompt_ids)
77
+ processed_image = shard(processed_image)
78
+
79
+ output = pipe(
80
+ prompt_ids=prompt_ids,
81
+ image=processed_image,
82
+ params=p_params,
83
+ # params = params,
84
+ prng_seed=rng,
85
+ num_inference_steps=50,
86
+ neg_prompt_ids=negative_prompt_ids,
87
+ jit=True,
88
+ ).images
89
+
90
+ output_image = pipe.numpy_to_pil(
91
+ np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
92
+ )
93
+
94
+ return output_image
95
+
96
+
97
+ title = "ControlNet on Battlemaps"
98
+ description = "This is a demo on ControlNet based on Bettlemaps."
99
+ # you need to pass inputs and outputs according to inference function
100
+ gr.Interface(
101
+ fn=infer,
102
+ inputs=["text", "text", "image"],
103
+ outputs="gallery",
104
+ title=title,
105
+ description=description,
106
+ ).launch()