Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,288 Bytes
5de8b22 a3f48ee 9281027 a3f48ee c2df784 a3f48ee 80aa27f c2df784 5068bdc a3f48ee c0eeab4 a3f48ee a973e8e a3f48ee a973e8e a3f48ee 4c550a6 a3f48ee 4c550a6 ce8130a a3f48ee 4fb8bf2 2324727 24843a9 2324727 24843a9 f7d4f2d 2324727 24843a9 f7d4f2d 2324727 c51c906 a3f48ee c51c906 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Float32 matmul precision setting passed to torch. The original selected
# this value via ["high", "highest"][0]; "highest" is the alternative that
# toggle offered.
torch.set_float32_matmul_precision("high")

# Load the segmentation model once at import time and move it to the GPU.
# NOTE(review): trust_remote_code=True lets the model repository supply the
# model implementation — confirm the pinned repo is trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing applied to every input image before inference: resize to a
# fixed 1024x1024, convert to a tensor, then normalize per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm against the model card.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Directory where processed images are written. exist_ok=True replaces the
# previous "check, then create" sequence, which could fail if the directory
# appeared between the check and the creation.
output_folder = 'output_images'
os.makedirs(output_folder, exist_ok=True)
def fn(image):
    """Remove the background from an image and save the result as a PNG.

    Args:
        image: Anything ``load_img`` accepts. In this file the function is
            wired to a ``gr.Image`` input and to a textbox holding a URL.

    Returns:
        A tuple ``((processed, original), path)`` where ``processed`` is the
        image returned by ``process``, ``original`` is an untouched RGB copy
        of the input, and ``path`` is the location of the saved PNG.
    """
    # Local import so this function stays self-contained.
    import tempfile

    im = load_img(image, output_type="pil").convert("RGB")
    # process() adds the alpha channel to the image it is given, so keep a
    # separate copy of the original for the before/after slider.
    origin = im.copy()
    result = process(im)

    # Save into a fresh per-request directory: a single shared file name
    # would let concurrent requests overwrite each other's output. The file
    # name itself is unchanged. These directories are not removed afterwards.
    request_dir = tempfile.mkdtemp(dir=output_folder)
    image_path = os.path.join(request_dir, "no_bg_image.png")
    result.save(image_path)
    return (result, origin), image_path
@spaces.GPU
def process(image):
    """Predict a foreground mask for ``image`` and attach it as alpha.

    The image is run through the module-level ``transform_image`` pipeline
    and the ``birefnet`` model on CUDA. The last element of the model output
    is passed through a sigmoid, converted to a PIL image, resized back to
    the input's original size and installed with ``putalpha``.

    Note that ``image`` itself is modified and returned; no copy is made.
    """
    original_size = image.size

    # Add a leading batch dimension and move the tensor to the GPU.
    batch = transform_image(image).unsqueeze(0)
    batch = batch.to("cuda")

    # Prediction (gradients are not needed for inference).
    with torch.no_grad():
        outputs = birefnet(batch)
        preds = outputs[-1].sigmoid().cpu()

    # Turn the first prediction into a mask image at the original size.
    to_pil = transforms.ToPILImage()
    alpha = to_pil(preds[0].squeeze())
    image.putalpha(alpha.resize(original_size))
    return image
def process_file(f):
    """Remove the background from the image at path ``f`` and save a PNG.

    Args:
        f: Path of the input image file.

    Returns:
        Path of the written PNG: the input path with its extension replaced
        by ``.png``.
    """
    # os.path.splitext only looks for an extension in the final path
    # component. Splitting the whole string on its last "." would cut the
    # path inside a dotted directory name when the file has no extension.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Input and output components shared by the interfaces below.
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")

# Example inputs: a local image (loaded eagerly) and a remote URL.
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

# One interface per input kind; tab1 and tab2 both go through fn().
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=[slider1, gr.File(label="output png file")],
    examples=[chameleon],
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn,
    inputs=text,
    outputs=[slider2, gr.File(label="output png file")],
    examples=[url],
    api_name="text",
)
# NOTE(review): tab3 is constructed but is not among the interfaces passed
# to TabbedInterface below — confirm whether that is intentional.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["giraffe.jpg"],
    api_name="png",
)

# HTML shown as the title of the tabbed demo.
_title_html = (
    "RMBG-2.0 for background removal <br>"
    "<span style='font-size:16px; font-weight:300;'>"
    "Background removal model developed by "
    "<a href='https://bria.ai' target='_blank'>BRIA.AI</a>, trained on a carefully selected dataset,<br> "
    "and is available as an open-source model for non-commercial use.</span><br>"
    "<span style='font-size:16px; font-weight:500;'> For testing upload your image and wait.<br>"
    "<a href='https://go.bria.ai/3ZCBTLH' target='_blank'>Commercial use license</a> | "
    "<a href='https://huggingface.co/briaai/RMBG-2.0' target='_blank'>Model card</a> | "
    "<a href='https://blog.bria.ai/brias-new-state-of-the-art-remove-background-2.0-outperforms-the-competition' target='_blank'>Blog</a>"
    "</span>"
)

demo = gr.TabbedInterface(
    [tab1, tab2],
    ["input image", "input url"],
    title=_title_html,
)

# Start the app only when run as a script; show_error=True is passed to
# launch() as in the original.
if __name__ == "__main__":
    demo.launch(show_error=True)
|