rrnoa's picture
marigold-v1-0
b1e03fc
import diffusers
import torch
from fastapi import FastAPI, UploadFile, HTTPException, File
from fastapi.responses import StreamingResponse
from PIL import Image
import io
app = FastAPI()
# Inicializa el pipeline al arrancar el servidor
@app.on_event("startup")
async def startup_event():
global pipe
print("[DEBUG] Cargando modelo Marigold-v1-0...")
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-v1-0", variant="fp16", torch_dtype=torch.float16
).to("cuda")
print("[DEBUG] Modelo Marigold-v1-0 cargado exitosamente.")
@app.post("/predict-depth/")
async def predict_depth(file: UploadFile = File(...)):
try:
# Verifica si el archivo es una imagen v谩lida
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.")
# Carga la imagen desde el archivo subido
image = Image.open(file.file).convert("RGB")
# Realiza la predicci贸n de profundidad
print("[DEBUG] Realizando predicci贸n de profundidad con Marigold-v1-0...")
depth = pipe(image)
# Exporta la profundidad como una imagen 16-bit PNG
depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
# Guarda la imagen generada en un buffer
img_buffer = io.BytesIO()
depth_16bit[0].save(img_buffer, format="PNG")
img_buffer.seek(0)
# Devuelve la imagen como respuesta
return StreamingResponse(img_buffer, media_type="image/png")
except Exception as e:
print(f"[ERROR] {str(e)}")
raise HTTPException(status_code=500, detail="Error procesando la imagen.")
@app.get("/")
async def root():
return {"message": "API de generaci贸n de mapas de profundidad con Marigold-v1-0"}