File size: 3,288 Bytes
5de8b22
a3f48ee
 
 
 
 
 
 
 
 
 
 
9281027
a3f48ee
 
 
 
 
 
 
 
 
 
c2df784
 
 
 
a3f48ee
 
 
 
80aa27f
c2df784
 
5068bdc
a3f48ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0eeab4
 
a3f48ee
 
 
 
 
 
a973e8e
a3f48ee
 
a973e8e
a3f48ee
4c550a6
a3f48ee
 
4c550a6
ce8130a
a3f48ee
 
4fb8bf2
2324727
 
24843a9
2324727
24843a9
f7d4f2d
 
2324727
24843a9
 
f7d4f2d
2324727
c51c906
 
 
a3f48ee
 
 
c51c906
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# "high" (rather than "highest") allows PyTorch to use faster, reduced-precision
# float32 matmul kernels where the hardware supports them.
torch.set_float32_matmul_precision("high")

# RMBG-2.0 segmentation model. trust_remote_code lets transformers run the
# modelling code shipped in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0",
    trust_remote_code=True,
)
birefnet.to("cuda")  # process() moves its input batch to "cuda" as well

# Preprocessing applied to each PIL image before inference:
# fixed 1024x1024 resize -> float tensor -> per-channel normalisation using
# the ImageNet mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Directory under which fn() saves its PNG results.
output_folder = 'output_images'
# exist_ok=True replaces the separate exists()/makedirs() pair. That pair could
# raise FileExistsError if the directory appeared between the check and the create.
os.makedirs(output_folder, exist_ok=True)

def fn(image):
    """Remove the background from one image and save the cut-out as a PNG.

    Args:
        image: Any input ``load_img`` accepts. In this app it is the value of
            a ``gr.Image`` component (image tab) or a URL string (URL tab).

    Returns:
        ``((cutout, original), png_path)``. The pair feeds the ``ImageSlider``
        (image with the predicted alpha vs. the untouched RGB original), and
        ``png_path`` is the saved PNG offered for download.
    """
    import tempfile

    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    # process() attaches the mask to its argument in place (putalpha), so keep
    # an untouched copy for the "original" side of the slider.
    origin = im.copy()
    image = process(im)
    # A single fixed output path would be shared by every call. Overlapping
    # requests (e.g. the image tab and the URL tab) could then overwrite each
    # other's file before it is served. A fresh directory per call isolates
    # results and keeps the user-facing file name "no_bg_image.png".
    # NOTE(review): these per-call directories are not cleaned up here.
    call_dir = tempfile.mkdtemp(dir=output_folder)
    image_path = os.path.join(call_dir, "no_bg_image.png")
    image.save(image_path)
    return (image, origin), image_path

@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as its alpha channel.

    ``image`` is modified in place via ``putalpha`` and is also returned.
    Callers that still need the untouched picture must copy it first.

    Args:
        image: A PIL image. The callers in this file pass RGB images.

    Returns:
        The same PIL image object, now carrying the predicted mask as alpha.
    """
    width_height = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")  # add batch dim

    with torch.no_grad():
        # The last element of the model's output sequence is taken as the
        # final prediction; sigmoid squashes it into [0, 1].
        outputs = birefnet(batch)
        probs = outputs[-1].sigmoid().cpu()

    # First (only) item of the batch -> PIL mask, scaled back to the input size.
    alpha = transforms.ToPILImage()(probs[0].squeeze()).resize(width_height)
    image.putalpha(alpha)
    return image
  
def process_file(f):
    """Remove the background from the image file at path ``f``.

    The result is saved next to the input, with the file extension replaced
    by ``.png``.

    Args:
        f: Filesystem path of the input image (the ``gr.Image`` here uses
            ``type="filepath"``).

    Returns:
        The path of the written PNG file.
    """
    # os.path.splitext only inspects the final path component. The former
    # f.rsplit(".", 1) would cut into a dotted directory name when the file
    # itself has no extension (e.g. "/tmp/gradio.abc/photo" -> "/tmp/gradio.png").
    # NOTE(review): a ".png" input is overwritten in place by its own result.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Output sliders: fn() returns (processed, original), which these display.
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")

# Input components: PIL/array upload, file-path upload, and URL text.
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")

# Download component used by tab3. tab1 and tab2 build their own gr.File inline.
png_file = gr.File(label="output png file")


# Example inputs.
# NOTE(review): despite its name, `chameleon` holds the giraffe.jpg example.
chameleon = load_img("giraffe.jpg", output_type="pil")

url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# One Interface per input mode.
tab1 = gr.Interface(
    fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)

tab2 = gr.Interface(
    fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)
# NOTE(review): tab3 is constructed but is not passed to the TabbedInterface
# that builds `demo`; confirm whether it should be exposed.
tab3 = gr.Interface(
    process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)


# Top-level app: the image-upload and URL tabs under a shared HTML heading.
# NOTE(review): tab3 (file-path input) is not included in this list.
demo = gr.TabbedInterface(
    [tab1, tab2],
    ["input image", "input url"],
    title="".join(
        [
            "RMBG-2.0 for background removal <br>",
            "<span style='font-size:16px; font-weight:300;'>",
            "Background removal model developed by ",
            "<a href='https://bria.ai' target='_blank'>BRIA.AI</a>, trained on a carefully selected dataset,<br> ",
            "and is available as an open-source model for non-commercial use.</span><br>",
            "<span style='font-size:16px; font-weight:500;'> For testing upload your image and wait.<br>",
            "<a href='https://go.bria.ai/3ZCBTLH' target='_blank'>Commercial use license</a> | ",
            "<a href='https://huggingface.co/briaai/RMBG-2.0' target='_blank'>Model card</a> | ",
            "<a href='https://blog.bria.ai/brias-new-state-of-the-art-remove-background-2.0-outperforms-the-competition' target='_blank'>Blog</a>",
            "</span>",
        ]
    ),
)

if __name__ == "__main__":
    # show_error=True asks Gradio to surface handler exceptions in the browser.
    demo.launch(show_error=True)