cavargas10's picture
Update app.py
410cd67 verified
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import torchvision.transforms.functional as TF
import numpy as np
import random
import imageio
import cv2
from easydict import EasyDict as edict
from PIL import Image, ImageOps
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux import PidiNetDetector, HEDdetector
from diffusers.utils import load_image
from huggingface_hub import HfApi
from pathlib import Path
from gradio_imageslider import ImageSlider
# Prompt-style presets. Each preset wraps the user's prompt in a template
# (``{prompt}`` is the substitution point) and carries a default negative prompt.
style_list = [
    dict(
        name="(No style)",
        prompt="{prompt}",
        negative_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
    ),
    dict(
        name="Cinematic",
        prompt="cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        negative_prompt="anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    ),
    dict(
        name="3D Model",
        prompt="professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        negative_prompt="ugly, deformed, noisy, low poly, blurry, painting",
    ),
]
# Lookup table: style name -> (prompt template, negative prompt).
styles = {
    preset["name"]: (preset["prompt"], preset["negative_prompt"])
    for preset in style_list
}
# Dropdown choices, in preset declaration order.
STYLE_NAMES = [*styles]
DEFAULT_STYLE_NAME = "(No style)"
# Largest signed 32-bit integer; used as the seed slider's upper bound.
MAX_SEED = np.iinfo(np.int32).max
# Per-session scratch files live in a ``tmp`` folder next to this script.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
def reset_canvas():
    """Return a Gradio update that restores the sketch canvas to blank white.

    The background, the single layer and the composite are each a fresh
    512x512 white RGB image, so no image object is shared between them.
    """
    def _blank_white():
        return Image.new("RGB", (512, 512), (255, 255, 255))

    editor_value = {
        "background": _blank_white(),
        "layers": [_blank_white()],
        "composite": _blank_white(),
    }
    return gr.update(value=editor_value)
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Expand a named style preset around the user's prompts.

    Args:
        style_name: Key into ``styles``; unknown names fall back to the
            ``DEFAULT_STYLE_NAME`` preset.
        positive: The user's prompt, substituted for ``{prompt}`` in the
            preset's template.
        negative: Extra user negative prompt, appended to the preset's
            built-in negative prompt. ``None`` and blank strings are ignored.

    Returns:
        ``(prompt, negative_prompt)`` ready to pass to the diffusion pipeline.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    extra_negative = (negative or "").strip()
    # Join with ", " instead of plain concatenation. "+" fused the preset's
    # last term with the user's first one (e.g. "painting" + "blurry" ->
    # "paintingblurry"). Empty parts are skipped, so an empty user negative
    # leaves the preset's negative prompt unchanged.
    combined_negative = ", ".join(part for part in (style_negative, extra_negative) if part)
    return template.replace("{prompt}", positive), combined_negative
def start_session(req: gr.Request):
    """Create this browser session's scratch directory under ``TMP_DIR``.

    The directory is named after ``req.session_hash``; an existing directory
    is left as-is.
    """
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
def end_session(req: gr.Request):
    """Delete this browser session's scratch directory under ``TMP_DIR``.

    Cleanup is best-effort. The directory may already be gone, or may never
    have been created if the load-time ``start_session`` did not run, so
    errors from ``rmtree`` (including a missing directory) are ignored
    instead of surfacing as a failed unload event.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
@spaces.GPU
def preprocess_image(image: dict,
                     prompt: str = "",
                     negative_prompt: str = "",
                     style_name: str = "",
                     num_steps: int = 25,
                     guidance_scale: float = 5,
                     controlnet_conditioning_scale: float = 1.0,
                     ) -> Tuple[Image.Image, Image.Image]:
    """Turn a user sketch into an image ready for the image-to-3D pipeline.

    Args:
        image: Value of the ``gr.ImageMask`` input. A dict whose
            ``'composite'`` entry is the flattened PIL sketch.
        prompt, negative_prompt, style_name: Combined via ``apply_style``.
        num_steps: Diffusion steps for the ControlNet pipeline.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: Strength of the sketch conditioning.

    Returns:
        ``(sketch, processed_image)``. ``sketch`` is the resized, inverted
        sketch fed to the ControlNet. ``processed_image`` is the generated
        image after TRELLIS's own ``preprocess_image`` step (presumably
        background removal, per the "Preload rembg" note at startup; confirm).

    Raises:
        gr.Error: If there is no sketch to process.
    """
    if image is None or image.get("composite") is None:
        raise gr.Error("Please draw or upload a sketch first.")
    # Force RGB: ImageOps.invert does not support every PIL mode (e.g. RGBA).
    sketch = image["composite"].convert("RGB")
    width, height = sketch.size
    # Rescale to roughly 1024*1024 pixels while keeping the aspect ratio.
    ratio = np.sqrt(1024. * 1024. / (width * height))
    # Snap down to multiples of 8 so the sketch and the generated image share
    # the same size. Assumption: the SDXL pipeline otherwise rounds these
    # dimensions itself (VAE scale factor 8). For the default 512x512 canvas
    # this is a no-op (1024x1024).
    new_width = max(8, int(width * ratio) // 8 * 8)
    new_height = max(8, int(height * ratio) // 8 * 8)
    # The canvas starts white. Inverting presumably yields the light-strokes-
    # on-dark input the scribble ControlNet expects; confirm against the model card.
    sketch = ImageOps.invert(sketch.resize((new_width, new_height)))
    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
    output = pipe_control(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=sketch,
        num_inference_steps=num_steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        guidance_scale=guidance_scale,
        width=new_width,
        height=new_height,
    ).images[0]
    processed_image = pipeline.preprocess_image(output)
    return (sketch, processed_image)
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian and a mesh into a dict of plain NumPy arrays.

    The result holds no GPU tensors, so it can live in a ``gr.State``
    between requests. ``unpack_state`` rebuilds the objects from it.

    Returns:
        ``{'gaussian': {**gs.init_params, '_xyz', '_features_dc', '_scaling',
        '_rotation', '_opacity'}, 'mesh': {'vertices', 'faces'}}``.
    """
    def _host(tensor):
        # Move to CPU before converting; .numpy() needs a CPU tensor.
        return tensor.cpu().numpy()

    gaussian_state = {**gs.init_params}
    gaussian_state['_xyz'] = _host(gs._xyz)
    gaussian_state['_features_dc'] = _host(gs._features_dc)
    gaussian_state['_scaling'] = _host(gs._scaling)
    gaussian_state['_rotation'] = _host(gs._rotation)
    gaussian_state['_opacity'] = _host(gs._opacity)

    mesh_state = {
        'vertices': _host(mesh.vertices),
        'faces': _host(mesh.faces),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild the Gaussian and mesh from a dict produced by ``pack_state``.

    All tensors are recreated on the ``'cuda'`` device.

    Args:
        state: Dict with a ``'gaussian'`` entry (constructor params plus the
            ``_xyz``/``_features_dc``/``_scaling``/``_rotation``/``_opacity``
            arrays) and a ``'mesh'`` entry (``vertices`` and ``faces`` arrays).

    Returns:
        ``(gs, mesh)``, where ``mesh`` is an ``edict`` with ``vertices`` and
        ``faces`` tensors. The old annotation wrongly promised a third ``str``
        element.
    """
    gaussian_state = state['gaussian']
    gs = Gaussian(
        aabb=gaussian_state['aabb'],
        sh_degree=gaussian_state['sh_degree'],
        # 'mininum' (sic) matches the keyword spelling Gaussian accepts here.
        mininum_kernel_size=gaussian_state['mininum_kernel_size'],
        scaling_bias=gaussian_state['scaling_bias'],
        opacity_bias=gaussian_state['opacity_bias'],
        scaling_activation=gaussian_state['scaling_activation'],
    )
    # Restore the learned parameters that pack_state stored as NumPy arrays.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gaussian_state[attr], device='cuda'))
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return ``seed`` unchanged, or a fresh random seed when requested.

    A random seed is drawn from ``[0, MAX_SEED)`` using NumPy's global RNG.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
@spaces.GPU
def image_to_3d(
    image: Tuple[Image.Image, Image.Image],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
) -> Tuple[dict, str]:
    """Generate a 3D asset from the processed sketch with TRELLIS.

    Args:
        image: Value of the ``ImageSlider``, i.e. the pair returned by
            ``preprocess_image``. Only the second (processed) image is used.
        seed: Seed forwarded to ``pipeline.run``.
        ss_guidance_strength, ss_sampling_steps: Sparse-structure sampler
            settings (stage 1).
        slat_guidance_strength, slat_sampling_steps: Structured-latent sampler
            settings (stage 2).
        req: Gradio request, used to locate the per-session directory.

    Returns:
        ``(state, video_path)``. ``state`` is the ``pack_state`` dict consumed
        by ``extract_glb``. ``video_path`` points to a 120-frame MP4 of the
        mesh's ``'normal'`` render.

    Raises:
        gr.Error: If the sketch has not been processed yet.
    """
    if image is None or len(image) < 2 or image[1] is None:
        raise gr.Error("Please process the sketch before generating the 3D asset.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    outputs = pipeline.run(
        image[1],
        seed=seed,
        # Both representations are needed: the mesh for the preview video
        # and the Gaussian for pack_state / GLB texture baking. Requesting
        # only "mesh" made outputs['gaussian'] below raise KeyError.
        formats=["gaussian", "mesh"],
        # preprocess_image() already ran TRELLIS's preprocessing on this image.
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    video = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """Bake the packed Gaussian + mesh into a GLB file for this session.

    Args:
        state: Dict produced by ``pack_state`` (held in a ``gr.State``).
        mesh_simplify: Forwarded to ``to_glb`` as ``simplify``.
        texture_size: Forwarded to ``to_glb`` as ``texture_size``.
        req: Gradio request, used to locate the per-session directory.

    Returns:
        The GLB path twice: once for the 3D viewer, once for the download button.
    """
    gaussian, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    glb_path = os.path.join(TMP_DIR, str(req.session_hash), 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
def reset_do_preprocess():
    """Return ``True`` unconditionally.

    NOTE(review): not referenced anywhere in the visible file; possibly a
    leftover event handler. Confirm before removing.
    """
    should_preprocess = True
    return should_preprocess
# Gradio UI: a sketch canvas and settings on the left, results on the right.
# Flow: "process sketch" runs SDXL + scribble ControlNet (preprocess_image),
# "Generate 3D" runs TRELLIS (image_to_3d), "Extract GLB" bakes a GLB (extract_glb).
with gr.Blocks(delete_cache=(600, 600)) as demo:
    # Header text. The emoji/arrows were previously mis-encoded (mojibake),
    # and the instructions named a nonexistent "Generate" button.
    gr.Markdown("""
## Sketch to 3D with TRELLIS
1. Fast sketch to image with SDXL Flash, using [@xinsir](https://huggingface.co/xinsir) [scribble sdxl controlnet](https://huggingface.co/xinsir/controlnet-scribble-sdxl-1.0) and [sdxl flash](https://huggingface.co/sd-community/sdxl-flash)
2. Scalable and versatile image to 3D generation using [TRELLIS](https://trellis3d.github.io/)
### 🎨🖌️ draw or upload a sketch, click "process sketch", then "Generate 3D" to create a 3D asset ✨
""")
    with gr.Row():
        with gr.Column():
            with gr.Column():
                # The canvas starts as a blank 512x512 white image (same value reset_canvas restores).
                image_prompt = gr.ImageMask(
                    label="Input sketch",
                    type="pil",
                    image_mode="RGB",
                    height=512,
                    value={
                        "background": Image.new("RGB", (512, 512), (255, 255, 255)),
                        "layers": [Image.new("RGB", (512, 512), (255, 255, 255))],
                        "composite": Image.new("RGB", (512, 512), (255, 255, 255)),
                    },
                )
                with gr.Row():
                    sketch_btn = gr.Button("process sketch")
                    generate_btn = gr.Button("Generate 3D")
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
            with gr.Accordion(label="Generation Settings", open=False):
                with gr.Tab(label="sketch-to-image generation"):
                    negative_prompt = gr.Textbox(label="Negative prompt")
                    num_steps = gr.Slider(
                        label="Number of steps",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=8,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.1,
                        maximum=10.0,
                        step=0.1,
                        value=5,
                    )
                    controlnet_conditioning_scale = gr.Slider(
                        label="controlnet conditioning scale",
                        minimum=0.5,
                        maximum=5.0,
                        step=0.01,
                        value=0.85,
                    )
                with gr.Tab(label="3D generation"):
                    seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    gr.Markdown("Stage 1: Sparse Structure Generation")
                    with gr.Row():
                        ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                        ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    gr.Markdown("Stage 2: Structured Latent Generation")
                    with gr.Row():
                        slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                        slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            with gr.Row():
                # Enabled only once a 3D asset exists (see the generate chain below).
                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
            gr.Markdown("""
*NOTE: GLB file can be downloaded after extraction.*
""")
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            # Shows (inverted sketch, processed image) as returned by preprocess_image.
            image_prompt_processed = ImageSlider(label="processed sketch", interactive=False, type="pil", height=512)
            model_output = LitModel3D(label="Extracted GLB", exposure=10.0, height=300)
            with gr.Row():
                download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
    # Holds the pack_state dict between "Generate 3D" and "Extract GLB".
    output_buf = gr.State()

    # Per-session scratch directory lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

    image_prompt.clear(
        fn=reset_canvas,
        outputs=[image_prompt],
    )

    # NOTE(review): the seed is refreshed here, but preprocess_image does not
    # take it as an input, so the SDXL sketch step itself is not seeded.
    sketch_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        preprocess_image,
        inputs=[image_prompt, prompt, negative_prompt, style, num_steps, guidance_scale, controlnet_conditioning_scale],
        outputs=[image_prompt_processed],
    )

    # Refresh the seed first so image_to_3d reads the (possibly randomized)
    # value, then enable GLB extraction once the asset exists.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt_processed, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[extract_glb_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )
if __name__ == "__main__":
    # These module-level globals (pipeline, pipe_control) are read by the
    # event handlers, so they only exist when the file is run as a script.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Scribble ControlNet on top of SDXL Flash, in fp16 with the fp16-safe VAE.
    controlnet = ControlNetModel.from_pretrained(
        "xinsir/controlnet-scribble-sdxl-1.0",
        torch_dtype=torch.float16,
    )
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipe_control = StableDiffusionXLControlNetPipeline.from_pretrained(
        "sd-community/sdxl-flash",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=torch.float16,
    )
    pipe_control.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe_control.scheduler.config)
    pipe_control.to(device)
    # Warm-up call on a black image. Per the original note, this preloads
    # rembg. It is best-effort, so a failure is reported but does not stop
    # startup. Catch Exception, not everything: a bare "except" also
    # swallowed KeyboardInterrupt/SystemExit and hid the cause.
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        print(f"Warning: rembg preload failed: {exc!r}")
    demo.launch()