geetu040's picture
fix index
b95c986
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from segmentation import predict as segmentation_predict
from depth_estimation import predict as depth_estimation_predict
def _colorize_depth(depth_image, color_map):
    """Render a depth image as an RGB PIL image using a matplotlib colormap.

    Args:
        depth_image: Depth map as a PIL image. Values are divided by 255
            before the colormap lookup, so this assumes an 8-bit
            single-channel image (mode "L") -- TODO confirm against
            depth_estimation.predict.
        color_map: Name of a matplotlib colormap (e.g. "viridis"); ``None``
            makes matplotlib pick its default colormap.

    Returns:
        An RGB ``PIL.Image.Image`` with the same width/height as the input.
    """
    depth_array = np.asarray(depth_image, dtype=np.float64) / 255.0
    colormap = plt.get_cmap(color_map)
    # bytes=True yields uint8 RGBA with the same (x * 255) -> uint8 scaling as a
    # manual conversion; slice off alpha so the result really is RGB.
    rgba = colormap(depth_array, bytes=True)
    return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))


def predict(image, color_map):
    """Segment the input, estimate depth on the segmented image, and colorize it.

    Args:
        image: Input PIL image from the Gradio ``Image`` component; ``None``
            when the user submits without uploading anything.
        color_map: Name of the matplotlib colormap chosen in the dropdown.

    Returns:
        RGB PIL image containing the colorized depth map.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalize to a 3-band image so it matches the black RGB background used
    # by Image.composite below (uploads may arrive as RGBA, L, P, ...).
    image = image.convert("RGB")

    # Keep pixels where the mask is white and fill the rest with black.
    # NOTE(review): Image.composite requires the mask to match image.size;
    # assumes segmentation_predict returns a same-sized mask -- confirm.
    mask_image = segmentation_predict(image)
    background = Image.new("RGB", image.size, (0, 0, 0))
    segmented_image = Image.composite(image, background, mask_image.convert("L"))

    depth_image = depth_estimation_predict(segmented_image)
    return _colorize_depth(depth_image, color_map)
# Colormap names offered in the UI dropdown, grouped the same way as
# matplotlib's colormap reference and concatenated in that order.
color_maps = (
    # Perceptually uniform sequential
    ["viridis", "plasma", "inferno", "magma", "cividis"]
    # Sequential
    + [
        "Greys", "Purples", "Blues", "Greens", "Oranges", "Reds",
        "YlOrBr", "YlOrRd", "OrRd", "PuRd", "RdPu", "BuPu",
        "GnBu", "PuBu", "YlGnBu", "PuBuGn", "BuGn", "YlGn",
    ]
    # Sequential (2)
    + [
        "binary", "gist_yarg", "gist_gray", "gray", "bone",
        "pink", "spring", "summer", "autumn", "winter", "cool",
        "Wistia", "hot", "afmhot", "gist_heat", "copper",
    ]
    # Diverging
    + [
        "PiYG", "PRGn", "BrBG", "PuOr", "RdGy", "RdBu", "RdYlBu",
        "RdYlGn", "Spectral", "coolwarm", "bwr", "seismic",
    ]
    # Cyclic
    + ["twilight", "twilight_shifted", "hsv"]
    # Qualitative
    + [
        "Pastel1", "Pastel2", "Paired", "Accent", "Dark2",
        "Set1", "Set2", "Set3", "tab10", "tab20", "tab20b", "tab20c",
    ]
    # Miscellaneous
    + [
        "flag", "prism", "ocean", "gist_earth", "terrain",
        "gist_stern", "gnuplot", "gnuplot2", "CMRmap",
        "cubehelix", "brg", "gist_rainbow", "rainbow", "jet",
        "turbo", "nipy_spectral", "gist_ncar",
    ]
)
# Example rows for the Gradio UI: [input image path, colormap name].
# The same sample photo is shown with two different colormaps.
examples = [
    ["assets/examples/myself.jpeg", cmap_name]
    for cmap_name in ("afmhot", "inferno")
]
# Gradio UI: an image plus a colormap name go in; the colorized depth map
# produced by `predict` comes out.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input image"),
        # Explicit default so a colormap is preselected regardless of how the
        # installed Gradio version initializes a Dropdown without `value`.
        gr.Dropdown(choices=color_maps, value="viridis", label="Color map"),
    ],
    outputs=gr.Image(type="pil", label="Colorized depth"),
    title="DepthPro: Colorify",
    description="Applies segmentation on the input image, then creates the depth map and finally colorizes it.",
    examples=examples,
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    interface.launch()