Spaces: kernel-luso-comfort
Commit 6ba63c9 · Parent(s): cbd253a
Add initial module structure and entry points for modeling and utilities

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- Dockerfile +77 -0
- README.md +5 -5
- colabs/ENVIRONMENT.md +6 -0
- colabs/biomedparse_inference_demo.py +156 -0
- colabs/environment.yml +149 -0
- colabs/requirements-colab-pip-freeze.txt +567 -0
- colabs/requirements-colab.txt +39 -0
- configs/biomedparse_inference.yaml +204 -0
- entrypoint.sh +5 -0
- examples/Part_1_516_pathology_breast.png +3 -0
- inference_utils/inference.py +149 -0
- inference_utils/output_processing.py +91 -0
- inference_utils/processing_utils.py +182 -0
- inference_utils/target_dist.json +1 -0
- main.py +106 -0
- modeling/BaseModel.py +45 -0
- modeling/__init__.py +1 -0
- modeling/architectures/__init__.py +5 -0
- modeling/architectures/build.py +22 -0
- modeling/architectures/seem_model_demo.py +923 -0
- modeling/architectures/seem_model_v0.py +1160 -0
- modeling/architectures/seem_model_v1.py +1179 -0
- modeling/architectures/xdecoder_model.py +937 -0
- modeling/body/__init__.py +10 -0
- modeling/body/build.py +13 -0
- modeling/body/xdecoder_head.py +126 -0
- modeling/interface/__init__.py +13 -0
- modeling/interface/build.py +14 -0
- modeling/interface/modules.py +200 -0
- modeling/interface/prototype/__init__.py +0 -0
- modeling/interface/prototype/attention_data_struct_seemdemo.py +265 -0
- modeling/interface/prototype/attention_data_struct_seemv0.py +264 -0
- modeling/interface/prototype/attention_data_struct_seemv1.py +302 -0
- modeling/interface/seem_demo.py +397 -0
- modeling/interface/seem_v0.py +392 -0
- modeling/interface/seem_v1.py +389 -0
- modeling/interface/xdecoder.py +497 -0
- modeling/language/LangEncoder/__init__.py +35 -0
- modeling/language/LangEncoder/build.py +16 -0
- modeling/language/LangEncoder/transformer.py +222 -0
- modeling/language/__init__.py +10 -0
- modeling/language/build.py +14 -0
- modeling/language/loss.py +232 -0
- modeling/language/misc.py +66 -0
- modeling/language/vlpencoder.py +206 -0
- modeling/modules/__init__.py +6 -0
- modeling/modules/attention.py +487 -0
- modeling/modules/criterion.py +874 -0
- modeling/modules/matcher.py +632 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,77 @@
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM continuumio/miniconda3:latest

# Add build argument to force rebuild
ARG CACHEBUST=1

# Avoid tzdata interactive configuration
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=UTC

# Install system dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    git \
    build-essential \
    python3-dev \
    wget \
    openmpi-bin \
    libopenmpi-dev \
    libopenmpi3 \
    libhwloc15 \
    libevent-dev \
    libpmix2 \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Set up OpenMPI environment
ENV OMPI_MCA_btl_vader_single_copy_mechanism=none \
    OMPI_ALLOW_RUN_AS_ROOT=1 \
    OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \
    PATH=/usr/lib/x86_64-linux-gnu/openmpi/bin:$PATH \
    LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/openmpi/lib:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH

# Copy environment file
COPY colabs/environment.yml /tmp/environment.yml

# Create conda environment
RUN conda env create -f /tmp/environment.yml && \
    conda run -n biomedparse pip install gradio==3.50.2

# Initialize conda in bash
RUN conda init bash

# Make RUN commands use the new environment
SHELL ["conda", "run", "-n", "biomedparse", "/bin/bash", "-c"]

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user

# Switch to the "user" user
USER user

# Set up HF token for the user
RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
    echo "export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)" >> $HOME/.bashrc

# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Copy all files to the app directory
COPY --chown=user . $HOME/app

# Set permissions for entrypoint script
RUN chmod 755 $HOME/app/entrypoint.sh

# Add conda environment to user's path
RUN echo "conda activate biomedparse" >> $HOME/.bashrc

# Use entrypoint script to set up environment and run application
ENTRYPOINT ["/bin/bash", "-c"]
CMD ["exec /home/user/app/entrypoint.sh"]
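The Dockerfile exports the HF_TOKEN Space secret into the runtime environment, and the application is then expected to pick it up before downloading the gated checkpoint. main.py is part of this commit but its diff is not shown in this view, so the following is only a hedged sketch of how such a script might consume the token, not the actual committed code:

import os
import huggingface_hub

token = os.environ.get("HF_TOKEN")  # exported via the Dockerfile secret mount / entrypoint.sh
if token:
    huggingface_hub.login(token=token)  # authenticate before downloading the gated biomedparse_v1.pt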
README.md
CHANGED
@@ -1,11 +1,11 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Biomedparse Docker
+emoji: 📉
+colorFrom: yellow
+colorTo: blue
 sdk: docker
 pinned: false
-
+license: cc-by-nc-sa-4.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
colabs/ENVIRONMENT.md
ADDED
@@ -0,0 +1,6 @@
# Description of Google Colab Environment

- Hardware: Python 3 Google Compute Engine Backend on T4 GPU
- CUDA version: 12.2
- Driver Version: 535.104.05
- Python version: 3.10.12
colabs/biomedparse_inference_demo.py
ADDED
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
"""biomedparse_inference_demo.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1jL4wvdtBWz6G_yBkFn8tyDD0hV1RtKVZ

# BiomedParse Inference Demo Notebook

Welcome to the demo notebook for BiomedParse, a comprehensive tool for biomedical image analysis. BiomedParse is designed to simultaneously handle segmentation, detection, and recognition tasks across major biomedical image modalities, providing a unified solution for complex image analysis in biomedical research.

[[`Paper`](https://aka.ms/biomedparse-paper)] [[`Demo`](https://microsoft.github.io/BiomedParse/)] [[`Model`](https://huggingface.co/microsoft/BiomedParse)] [[`Data`](https://huggingface.co/datasets/microsoft/BiomedParseData)]

## Model Checkpoint Access

The BiomedParse model checkpoint is hosted on [HuggingFace](https://huggingface.co/microsoft/BiomedParse). To access the model:

1. Visit the [model page](https://huggingface.co/microsoft/BiomedParse).
2. Make sure to review and accept the terms of use to gain access to the checkpoint.
3. Retrieve your HuggingFace access token from your user profile.

## Setting Up Access

To use the model, set your Hugging Face access token in the HF_TOKEN environment variable or as a Colab secret. This step ensures secure and authorized access to the model resources.
"""

# Set your Hugging Face access token in your environment
# import os
# os.environ['HF_TOKEN'] = 'your_huggingface_access_token_here'

# Or, if you are using Google Colab, set HF_TOKEN on Colab secrets.

from google.colab import userdata
import huggingface_hub

huggingface_hub.login(userdata.get('HF_TOKEN'))

from huggingface_hub import hf_hub_download

model_file = hf_hub_download(repo_id="microsoft/BiomedParse", filename="biomedparse_v1.pt", local_dir="pretrained")

print(f"Downloaded model file to: {model_file}")

"""## Environment Setup"""

!git clone https://github.com/microsoft/BiomedParse

!pip install -r BiomedParse/assets/requirements/requirements.txt

"""# Restart Colab Runtime"""

# Make sure to restart Colab runtime after installing dependencies
import os
try:
    import google.colab
    os._exit(0)
except ImportError:
    pass

import os
os.chdir('/content/BiomedParse')
print(os.getcwd())

"""## Load the model weights"""

from PIL import Image
import torch
import argparse
import numpy as np
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed # changed from utils
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image

conf_files = "configs/biomedparse_inference.yaml"
opt = load_opt_from_config_files([conf_files])
opt = init_distributed(opt)

model_file = "../pretrained/biomedparse_v1.pt"

model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
with torch.no_grad():
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)

"""# Run Inference"""

# RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
image = image.convert('RGB')

# text prompts querying objects in the image. Multiple ones can be provided.
prompts = ['neoplastic cells', 'inflammatory cells']

pred_mask = interactive_infer_image(model, image, prompts)
pred_mask.shape

# load ground truth mask
gt_masks = []
for prompt in prompts:
    gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
    gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
    gt_masks.append(gt_mask)

# prediction with ground truth mask
for i, pred in enumerate(pred_mask):
    gt = gt_masks[i]
    dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
    print(f'Dice score for {prompts[i]}: {dice:.4f}')

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import matplotlib.patches as mpatches

def overlay_masks(image, masks, colors):
    overlay = image.copy()
    overlay = np.array(overlay, dtype=np.uint8)
    for mask, color in zip(masks, colors):
        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(np.uint8)
    return Image.fromarray(overlay)

def generate_colors(n):
    cmap = plt.get_cmap('tab10')
    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
    return colors

original_image = Image.open('examples/Part_1_516_pathology_breast.png').convert('RGB')

colors = generate_colors(len(prompts))

pred_overlay = overlay_masks(original_image, [1*(pred_mask[i] > 0.5) for i in range(len(prompts))], colors)

gt_overlay = overlay_masks(original_image, gt_masks, colors)

legend_patches = [mpatches.Patch(color=np.array(color) / 255, label=prompt) for color, prompt in zip(colors, prompts)]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(original_image)
axes[0].set_title("Original Image")
axes[0].axis('off')

axes[1].imshow(pred_overlay)
axes[1].set_title("Predictions")
axes[1].axis('off')
axes[1].legend(handles=legend_patches, loc='upper right', fontsize='small')

axes[2].imshow(gt_overlay)
axes[2].set_title("Ground Truth")
axes[2].axis('off')
axes[2].legend(handles=legend_patches, loc='upper right', fontsize='small')

plt.tight_layout()
plt.show()
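As an aside, the inline Dice computation in the demo above can be factored into a small helper. This is a minimal illustrative sketch; the dice_score name and the 0.5 default threshold are not part of the committed files:

import numpy as np

def dice_score(pred_prob: np.ndarray, gt_mask: np.ndarray, threshold: float = 0.5) -> float:
    # Dice = 2 * |pred ∩ gt| / (|pred| + |gt|), with the probability map thresholded first
    pred = (pred_prob > threshold).astype(np.uint8)
    gt = (gt_mask > 0).astype(np.uint8)
    denom = pred.sum() + gt.sum()
    return 2.0 * (pred & gt).sum() / denom if denom > 0 else 1.0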
colabs/environment.yml
ADDED
@@ -0,0 +1,149 @@
name: biomedparse
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39h6a678d5_8
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - certifi=2024.7.4=py39h06a4308_0
  - charset-normalizer=3.3.2=pyhd3eb1b0_0
  - cuda-cudart=12.4.127=0
  - cuda-cupti=12.4.127=0
  - cuda-libraries=12.4.0=0
  - cuda-nvrtc=12.4.127=0
  - cuda-nvtx=12.4.127=0
  - cuda-opencl=12.6.37=0
  - cuda-runtime=12.4.0=0
  - cuda-version=12.6=3
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py39h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py39heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.7=py39h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.4=py39h06a4308_0
  - jpeg=9e=h5eee18b_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=12.4.2.65=0
  - libcufft=11.2.0.44=0
  - libcufile=1.11.0.15=0
  - libcurand=10.3.7.37=0
  - libcusolver=11.6.0.99=0
  - libcusparse=12.3.0.142=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h5eee18b_3
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libnpp=12.2.5.2=0
  - libnvfatbin=12.6.20=0
  - libnvjitlink=12.4.99=0
  - libnvjpeg=12.3.1.89=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_1
  - markupsafe=2.1.3=py39h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py39h5eee18b_1
  - mkl_fft=1.3.8=py39h5eee18b_0
  - mkl_random=1.2.4=py39hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py39h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.2.1=py39h06a4308_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.5.2=he7f1fd0_0
  - openssl=3.0.14=h5eee18b_0
  - pip=24.2=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.19=h955ad1f_1
  - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
  - pytorch-cuda=12.4=hc786d27_6
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py39h5eee18b_0
  - readline=8.2=h5eee18b_0
  - requests=2.32.3=py39h06a4308_0
  - setuptools=72.1.0=py39h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - sympy=1.12=py39h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.14=h39e8969_0
  - torchaudio=2.4.0=py39_cu124
  - torchtriton=3.0.0=py39
  - torchvision=0.19.0=py39_cu124
  - typing_extensions=4.11.0=py39h06a4308_0
  - tzdata=2024a=h04d1e81_0
  - urllib3=2.2.2=py39h06a4308_0
  - wheel=0.43.0=py39h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_1
  - zstd=1.5.5=hc292b87_2
  - pip:
    - accelerate==0.23.0
    - antlr4-python3-runtime==4.9.3
    - appdirs==1.4.4
    - black==21.4b2
    - open-clip-torch==2.26.1
    - cloudpickle==3.0.0
    - cython==3.0.2
    - deepspeed==0.10.3
    - git+https://github.com/MaureenZOU/detectron2-xyz.git
    - diffdist==0.1
    - einops==0.8.0
    - ftfy==6.1.1
    - fvcore==0.1.5.post20221221
    - hjson==3.1.0
    - huggingface-hub==0.17.3
    - hydra-core==1.3.2
    - imageio==2.35.1
    - infinibatch==0.1.1
    - iopath==0.1.9
    - json-tricks==3.17.3
    - kornia==0.7.0
    - mpi4py==3.1.5
    - mup==1.0.0
    - mypy-extensions==1.0.0
    - ninja==1.11.1.1
    - nltk==3.8.1
    - numpy==1.23.1
    - omegaconf==2.3.0
    - opencv-python==4.8.1.78
    - pandas==2.0.3
    - pathspec==0.12.1
    - pillow==9.4.0
    - portalocker==2.10.1
    - py-cpuinfo==9.0.0
    - pycocotools==2.0.7
    - pydantic==1.10.18
    - pydot==3.0.1
    - regex==2023.10.3
    - scikit-image==0.21.0
    - scikit-learn==1.3.1
    - sentencepiece==0.1.99
    - tabulate==0.9.0
    - termcolor==2.4.0
    - timm==0.4.12
    - tokenizers==0.14.1
    - transformers==4.34.0
    - vision-datasets==0.2.2
    - yacs==0.1.8
colabs/requirements-colab-pip-freeze.txt
ADDED
@@ -0,0 +1,567 @@
absl-py==1.4.0
accelerate==0.23.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.19
albumentations==1.4.20
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.20.0
astropy==6.1.7
astropy-iers-data==0.2024.12.16.0.35.48
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.3.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bigframes==1.29.0
bigquery-magics==0.4.0
black==21.4b2
bleach==6.2.0
blinker==1.9.0
blis==0.7.11
blosc2==2.7.1
bokeh==3.6.2
Bottleneck==1.4.2
bqplot==0.12.43
branca==0.8.1
CacheControl==0.14.1
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.12.14
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.0
chex==0.1.88
clarabel==0.9.0
click==8.1.7
cloudpathlib==0.20.0
cloudpickle==3.1.0
cmake==3.31.2
cmdstanpy==1.2.5
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contourpy==1.3.1
cryptography==43.0.3
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.6.0
cycler==0.12.1
cymem==2.0.10
Cython==3.0.2
dask==2024.10.0
datascience==0.17.6
db-dtypes==1.3.1
dbus-python==1.2.18
debugpy==1.8.0
decorator==4.4.2
deepspeed==0.10.3
defusedxml==0.7.1
Deprecated==1.2.15
detectron2 @ git+https://github.com/MaureenZOU/detectron2-xyz.git@42121d75e10d9f858f3a91b6a39f5722c02868f0
diffdist==0.1
diffusers==0.31.0
distro==1.9.0
dlib==19.24.2
dm-tree==0.1.8
docker-pycreds==0.4.0
docstring_parser==0.16
docutils==0.21.2
dopamine_rl==4.1.0
duckdb==1.1.3
earthengine-api==1.4.3
easydict==1.13
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et_xmlfile==2.0.0
etils==1.11.0
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.18
fastcore==1.7.27
fastdownload==0.0.7
fastjsonschema==2.21.1
fastprogress==1.0.3
fastrlock==0.8.3
filelock==3.16.1
firebase-admin==6.6.0
Flask==3.1.0
flatbuffers==24.3.25
flax==0.8.5
folium==0.19.2
fonttools==4.55.3
frozendict==2.4.6
frozenlist==1.5.0
fsspec==2024.10.0
ftfy==6.1.1
future==1.0.0
fvcore==0.1.5.post20221221
gast==0.6.0
gcsfs==2024.10.0
GDAL==3.6.4
gdown==5.2.0
geemap==0.35.1
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
gitdb==4.0.11
GitPython==3.1.43
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.10
google-api-core==2.19.2
google-api-python-client==2.155.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.74.0
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.17.0
google-cloud-bigquery-storage==2.27.0
google-cloud-bigtable==2.27.0
google-cloud-core==2.4.1
google-cloud-datastore==2.20.2
google-cloud-firestore==2.19.0
google-cloud-functions==1.19.0
google-cloud-iam==2.17.0
google-cloud-language==2.16.0
google-cloud-pubsub==2.27.1
google-cloud-resource-manager==1.14.0
google-cloud-storage==2.19.0
google-cloud-translate==3.19.0
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
google-crc32c==1.6.0
google-genai==0.3.0
google-generativeai==0.8.3
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.66.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.1
grpc-google-iam-v1==0.13.1
grpcio==1.68.1
grpcio-status==1.62.3
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h11==0.14.0
h5netcdf==1.4.1
h5py==3.12.1
hjson==3.1.0
holidays==0.63
holoviews==1.20.0
html5lib==1.1
httpcore==1.0.7
httpimport==1.4.0
httplib2==0.22.0
httpx==0.28.1
huggingface-hub==0.17.3
humanize==4.11.0
hydra-core==1.3.2
hyperopt==0.2.7
ibis-framework==9.2.0
idna==3.10
imageio==2.36.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.4
imgaug==0.4.0
immutabledict==4.2.1
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
infinibatch==0.1.1
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2025.0.4
intel-openmp==2025.0.4
iopath==0.1.9
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.33
jax-cuda12-pjrt==0.4.33
jax-cuda12-plugin==0.4.33
jaxlib==0.4.33
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
jiter==0.8.2
joblib==1.4.2
json-tricks==3.17.3
jsonpatch==1.33
jsonpickle==4.0.1
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.3.5
keras==3.5.0
keyring==23.5.0
kiwisolver==1.4.7
kornia==0.7.0
langchain==0.3.12
langchain-core==0.3.25
langchain-text-splitters==0.3.3
langcodes==3.5.0
langsmith==0.2.3
language_data==1.3.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-24.10.1-py3-none-manylinux_2_28_x86_64.whl
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==5.3.0
marisa-trie==1.2.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.8.0
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==3.0.2
mizani==0.13.1
mkl==2025.0.1
ml-dtypes==0.4.1
mlxtend==0.23.3
more-itertools==10.5.0
moviepy==1.0.3
mpi4py==3.1.5
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
mup==1.0.0
murmurhash==1.0.11
music21==9.3.0
mypy-extensions==1.0.0
namex==0.0.8
narwhals==1.18.4
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.1
nbconvert==7.16.4
nbformat==5.10.4
ndindex==1.9.2
nest-asyncio==1.6.0
networkx==3.4.2
nibabel==5.3.2
ninja==1.11.1.3
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.2
numpy==1.26.4
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvcc-cu12==12.6.85
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.6.0.74
nvidia-cufft-cu12==11.3.0.4
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.85
nvtx==0.2.10
nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-24.10.0-py3-none-any.whl
oauth2client==4.1.3
oauthlib==3.2.2
omegaconf==2.3.0
open_clip_torch==2.26.1
openai==1.57.4
opencv-contrib-python==4.10.0.84
opencv-python==4.8.1.78
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opentelemetry-api==1.29.0
opentelemetry-sdk==1.29.0
opentelemetry-semantic-conventions==0.50b0
opt_einsum==3.4.0
optax==0.2.4
optree==0.13.1
orbax-checkpoint==0.6.4
orjson==3.10.12
osqp==0.6.7.post3
packaging==24.2
pandas==2.0.3
pandas-datareader==0.10.0
pandas-gbq==0.25.0
pandas-stubs==2.2.2.240909
pandocfilters==1.5.1
panel==1.5.4
param==2.2.0
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
pathspec==0.12.1
patsy==1.0.1
peewee==3.17.8
peft==0.14.0
pexpect==4.9.0
pickleshare==0.7.5
Pillow==9.4.0
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.14.4
pluggy==1.5.0
ply==3.11
polars==1.9.0
pooch==1.8.2
portalocker==3.0.0
portpicker==1.5.2
preshed==3.0.9
prettytable==3.12.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.21.1
promise==2.3
prompt_toolkit==3.0.48
propcache==0.2.1
prophet==1.1.6
proto-plus==1.25.0
protobuf==4.25.5
psutil==5.9.5
psycopg2==2.9.10
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.7
pycparser==2.22
pydantic==1.10.19
pydantic_core==2.27.1
pydata-google-auth==1.9.0
pydot==3.0.3
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.21.3
pyerfa==2.0.1.5
pygame==2.6.1
pygit2==1.16.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.10.1
pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
pylibcugraph-cu12==24.10.0
pylibraft-cu12==24.10.0
pymc==5.19.1
pymystem3==0.2.0
pynvjitlink-cu12==0.4.0
pyogrio==0.10.0
Pyomo==6.8.2
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.2.0
pyperclip==1.9.0
pyproj==3.7.0
pyshp==2.3.1
PySocks==1.7.1
pyspark==3.5.3
pytensor==2.26.4
pytest==8.3.4
python-apt==0.0.0
python-box==7.3.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.9.1
pytz==2024.2
pyviz_comms==3.0.3
PyWavelets==1.8.0
PyYAML==6.0.1
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2023.10.3
requests==2.32.3
requests-oauthlib==1.3.1
requests-toolbelt==1.0.0
requirements-parser==0.9.0
rich==13.9.4
rmm-cu12==24.10.0
rpds-py==0.22.3
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.21.0
scikit-learn==1.3.1
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.2
SecretStorage==3.3.1
Send2Trash==1.8.3
sentence-transformers==3.3.1
sentencepiece==0.1.99
sentry-sdk==2.19.2
setproctitle==1.3.4
shap==0.46.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.17.0
sklearn-pandas==2.2.0
slicer==0.0.8
smart-open==7.1.0
smmap==5.0.1
sniffio==1.3.1
snowballstemmer==2.2.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==8.1.3
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.36
sqlglot==25.1.0
sqlparse==0.5.3
srsly==2.5.0
stanio==0.5.1
statsmodels==0.14.4
StrEnum==0.4.15
stringzilla==3.11.1
sympy==1.13.1
tables==3.10.1
tabulate==0.9.0
tbb==2022.0.0
tcmlib==1.2.0
tenacity==9.0.0
tensorboard==2.17.1
tensorboard-data-server==0.7.2
tensorflow==2.17.1
tensorflow-datasets==4.9.7
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.13.1
tensorflow-probability==0.24.0
tensorstore==0.1.71
termcolor==2.5.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.12.12
timm==0.4.12
tinycss2==1.4.0
tokenizers==0.14.1
toml==0.10.2
tomli==2.2.1
toolz==0.12.1
torch @ https://download.pytorch.org/whl/cu121_full/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
torchsummary==1.5.1
torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl
tornado==6.3.3
tqdm==4.67.1
traitlets==5.7.1
traittypes==0.2.1
transformers==4.34.0
tweepy==4.14.0
typeguard==4.4.1
typer==0.15.1
types-pytz==2024.2.0.20241003
types-setuptools==75.6.0.20241126
typing_extensions==4.12.2
tzdata==2024.2
tzlocal==5.2
uc-micro-py==1.0.3
umf==0.9.1
uritemplate==4.1.1
urllib3==2.2.3
vega-datasets==0.9.0
vision-datasets==0.2.2
wadllib==1.3.6
wandb==0.19.1
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
websockets==14.1
Werkzeug==3.1.3
widgetsnbextension==3.6.10
wordcloud==1.9.4
wrapt==1.17.0
xarray==2024.11.0
xarray-einstats==0.8.0
xgboost==2.1.3
xlrd==2.0.1
xyzservices==2024.9.0
yacs==0.1.8
yarl==1.18.3
yellowbrick==1.5
yfinance==0.2.50
zipp==3.21.0
colabs/requirements-colab.txt
ADDED
@@ -0,0 +1,39 @@
pillow==9.4.0
opencv-python==4.8.1.78
pyyaml==6.0.1
json_tricks==3.17.3
yacs==0.1.8
scikit-learn==1.3.1
pandas==2.0.3
timm==0.4.12
numpy==1.26.4
einops==0.8.0
fvcore==0.1.5.post20221221
transformers==4.34.0
sentencepiece==0.1.99
ftfy==6.1.1
regex==2023.10.3
nltk==3.8.1
mpi4py==3.1.5
vision-datasets==0.2.2
cython==3.0.2
pycocotools==2.0.7
diffdist==0.1
#pyarrow==13.0.0
#cityscapesscripts==2.2.2
#shapely==1.8.0
scikit-image==0.21.0
mup==1.0.0
accelerate==0.23.0
kornia==0.7.0
deepspeed==0.10.3
#wandb==0.15.12
infinibatch==0.1.1
open-clip-torch==2.26.1
git+https://github.com/MaureenZOU/detectron2-xyz.git
#gradio==3.42.0
#torch==2.3.1 #2.0.1
#torchvision==0.15.2
#torchaudio==2.0.2
#torch==2.1.0
#torchvision==0.16.0
configs/biomedparse_inference.yaml
ADDED
@@ -0,0 +1,204 @@
# Define Test/Trainer/Saving
PIPELINE: XDecoderPipeline
TRAINER: xdecoder
SAVE_DIR: "../../data/output/test"
base_path: "./"

# Resume Logistic
RESUME: false
WEIGHT: false
RESUME_FROM: ""
EVAL_AT_START: false

# Logging and Debug
WANDB: False
LOG_EVERY: 100
FIND_UNUSED_PARAMETERS: false

# Speed up training
FP16: false
PORT: "36873"

# misc
LOADER:
  JOINT: False
  KEY_DATASET: "coco"

STANDARD_TEXT_FOR_EVAL: False

##################
# Task settings
##################
VERBOSE: true
MODEL:
  NAME: seem_model_demo
  HEAD: xdecoder_head
  DIM_PROJ: 512
  TEXT:
    ARCH: vlpencoder
    NAME: transformer
    TOKENIZER: clip
    CONTEXT_LENGTH: 77 # 77
    WIDTH: 512
    HEADS: 8
    LAYERS: 12 # 6
    AUTOGRESSIVE: True
  BACKBONE:
    NAME: focal
    PRETRAINED: ""
    LOAD_PRETRAINED: false
    FOCAL:
      PRETRAIN_IMG_SIZE: 224
      PATCH_SIZE: 4
      EMBED_DIM: 192
      DEPTHS: [2, 2, 18, 2]
      FOCAL_LEVELS: [4, 4, 4, 4]
      FOCAL_WINDOWS: [3, 3, 3, 3]
      DROP_PATH_RATE: 0.3
      MLP_RATIO: 4.0
      DROP_RATE: 0.0
      PATCH_NORM: True
      USE_CONV_EMBED: True
      SCALING_MODULATOR: True
      USE_CHECKPOINT: False
      USE_POSTLN: true
      USE_POSTLN_IN_MODULATION: false
      USE_LAYERSCALE: True
      OUT_FEATURES: ["res2", "res3", "res4", "res5"]
      OUT_INDICES: [0, 1, 2, 3]
  ENCODER:
    NAME: transformer_encoder_fpn
    IGNORE_VALUE: 255
    NUM_CLASSES: 16
    BINARY_CLASSES: False
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 512
    MASK_DIM: 512
    NORM: "GN"
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
    COMMON_STRIDE: 4
    TRANSFORMER_ENC_LAYERS: 6
  DECODER:
    NAME: seem_demo
    TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
    MASK:
      ENABLED: False
    DETECTION: False
    SPATIAL:
      ENABLED: True
      MAX_ITER: 1
    GROUNDING:
      ENABLED: True
      MAX_LEN: 5
      TEXT_WEIGHT: 2.0
      CLASS_WEIGHT: 0.5
    VISUAL:
      ENABLED: False
    AUDIO:
      ENABLED: False
    RETRIEVAL:
      ENABLED: False
    LVIS:
      ENABLED: True
      THRES: 0.7
    OPENIMAGE:
      ENABLED: False
      NEGATIVE_SAMPLES: 5
      GROUNDING:
        ENABLED: False
        MAX_LEN: 5
    CAPTION:
      ENABLED: False
      PHRASE_PROB: 0.5
      SIM_THRES: 0.95
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 0.1
    GCLASS_WEIGHT: 0.4
    GMASK_WEIGHT: 1.0
    GDICE_WEIGHT: 1.0
    SCLASS_WEIGHT: 0.4
    SMASK_WEIGHT: 1.0
    SDICE_WEIGHT: 1.0
    OCLASS_WEIGHT: 0.4
    OMASK_WEIGHT: 1.0
    ODICE_WEIGHT: 1.0
    CLASS_WEIGHT: 2.0
    MASK_WEIGHT: 5.0
    DICE_WEIGHT: 5.0
    BBOX_WEIGHT: 5.0
    GIOU_WEIGHT: 2.0
    CAPTION_WEIGHT: 2.0
    COST_SPATIAL:
      CLASS_WEIGHT: 5.0
      MASK_WEIGHT: 2.0
      DICE_WEIGHT: 2.0
    HIDDEN_DIM: 512
    NUM_OBJECT_QUERIES: 101
    NHEADS: 8
    DROPOUT: 0.0
    DIM_FEEDFORWARD: 2048
    MAX_SPATIAL_LEN: [512, 512, 512, 512]
    # ENC_LAYERS: 0
    PRE_NORM: False
    ENFORCE_INPUT_PROJ: False
    SIZE_DIVISIBILITY: 32
    TRAIN_NUM_POINTS: 12544
    OVERSAMPLE_RATIO: 3.0
    IMPORTANCE_SAMPLE_RATIO: 0.75
    DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
    TOP_GROUNDING_LAYERS: 10
    TOP_CAPTION_LAYERS: 10
    TOP_SPATIAL_LAYERS: 10
    TOP_OPENIMAGE_LAYERS: 10
    TEST:
      SEMANTIC_ON: True
      INSTANCE_ON: True
      PANOPTIC_ON: True
      OVERLAP_THRESHOLD: 0.8
      OBJECT_MASK_THRESHOLD: 0.4
      SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
      DETECTIONS_PER_IMAGE: 100

# Multi-modal Architecture, order matters
ATTENTION_ARCH:
  VARIABLE:
    queries: ["object"]
    tokens: ["grounding", "spatial", "visual", "audio"]
  SELF_ATTENTION:
    queries:
      object:
        [
          "queries_object",
          "tokens_grounding",
          "tokens_spatial",
          "tokens_visual",
          "tokens_audio",
        ]
    tokens:
      grounding: ["queries_object", "tokens_grounding"]
      spatial: ["tokens_spatial"]
      visual: ["tokens_visual"]
      audio: ["queries_object", "tokens_audio"]
  CROSS_ATTENTION:
    queries:
      object: True
    tokens:
      grounding: False
      spatial: False
      visual: False
      audio: False
  MASKING:
    ["tokens_spatial", "tokens_grounding", "tokens_visual", "tokens_audio"]
  DUPLICATION:
    queries:
      grounding: "queries_object"
      spatial: "queries_object"
  SPATIAL_MEMORIES: 32

INPUT:
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
# INPUT:
#   PIXEL_MEAN: [64.284, 59.293, 59.962]
#   PIXEL_STD: [62.484, 60.865, 59.835]
entrypoint.sh
ADDED
@@ -0,0 +1,5 @@
#!/bin/bash
if [ -f "/run/secrets/HF_TOKEN" ]; then
    export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)
fi
exec conda run --no-capture-output -n biomedparse python main.py
examples/Part_1_516_pathology_breast.png
ADDED
Git LFS Details
inference_utils/inference.py
ADDED
@@ -0,0 +1,149 @@
import torch
import numpy as np
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
#from utils.visualizer import Visualizer
# from detectron2.utils.colormap import random_color
# from detectron2.data import MetadataCatalog
# from detectron2.structures import BitMasks
from modeling.language.loss import vl_similarity
from utilities.constants import BIOMED_CLASSES
#from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

# import cv2
# import os
# import glob
# import subprocess
from PIL import Image
import random

t = []
t.append(transforms.Resize((1024, 1024), interpolation=Image.BICUBIC))
transform = transforms.Compose(t)
#metadata = MetadataCatalog.get('coco_2017_train_panoptic')
all_classes = ['background'] + [name.replace('-other','').replace('-merged','')
                                for name in BIOMED_CLASSES] + ["others"]
# colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]

# use color list from matplotlib
import matplotlib.colors as mcolors
colors = dict(mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS)
colors_list = [list(colors.values())[i] for i in range(16)]

from .output_processing import mask_stats, combine_masks


@torch.no_grad()
def interactive_infer_image(model, image, prompts):

    image_resize = transform(image)
    width = image.size[0]
    height = image.size[1]
    image_resize = np.asarray(image_resize)
    image = torch.from_numpy(image_resize.copy()).permute(2,0,1).cuda()

    data = {"image": image, 'text': prompts, "height": height, "width": width}

    # inistalize task
    model.model.task_switch['spatial'] = False
    model.model.task_switch['visual'] = False
    model.model.task_switch['grounding'] = True
    model.model.task_switch['audio'] = False
    model.model.task_switch['grounding'] = True


    batch_inputs = [data]
    results,image_size,extra = model.model.evaluate_demo(batch_inputs)

    pred_masks = results['pred_masks'][0]
    v_emb = results['pred_captions'][0]
    t_emb = extra['grounding_class']

    t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
    v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

    temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
    out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

    matched_id = out_prob.max(0)[1]
    pred_masks_pos = pred_masks[matched_id,:,:]
    pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]

    # interpolate mask to ori size
    pred_mask_prob = F.interpolate(pred_masks_pos[None,], (data['height'], data['width']),
                                   mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
    pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)

    return pred_mask_prob



# def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None):
#     image_ori = transform(image)
#     #mask_ori = image['mask']
#     width = image_ori.size[0]
#     height = image_ori.size[1]
#     image_ori = np.asarray(image_ori)
#     visual = Visualizer(image_ori, metadata=metadata)
#     images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()

#     data = {"image": images, "height": height, "width": width}
#     if len(tasks) == 0:
#         tasks = ["Panoptic"]

#     # inistalize task
#     model.model.task_switch['spatial'] = False
#     model.model.task_switch['visual'] = False
#     model.model.task_switch['grounding'] = False
#     model.model.task_switch['audio'] = False

#     # check if reftxt is list of strings
#     assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}"
#     model.model.task_switch['grounding'] = True
#     predicts = {}
#     for i, txt in enumerate(reftxt):
#         data['text'] = txt
#         batch_inputs = [data]

#         results,image_size,extra = model.model.evaluate_demo(batch_inputs)

#         pred_masks = results['pred_masks'][0]
#         v_emb = results['pred_captions'][0]
#         t_emb = extra['grounding_class']

#         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
#         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

#         temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
#         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

#         matched_id = out_prob.max(0)[1]
#         pred_masks_pos = pred_masks[matched_id,:,:]
#         pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]

#         # interpolate mask to ori size
#         #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
#         # masks.append(pred_masks_pos[0])
#         # mask = pred_masks_pos[0]
#         # masks.append(mask)
#         # interpolate mask to ori size
#         pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
#         #pred_masks_pos = 1*(pred_mask_prob > 0.5)
#         predicts[txt] = pred_mask_prob[0]

#     masks = combine_masks(predicts)

#     predict_mask_stats = {}
#     print(masks.keys())
#     for i, txt in enumerate(masks):
#         mask = masks[txt]
#         demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt)
#         predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori)

#     res = demo.get_image()
#     torch.cuda.empty_cache()
#     # return Image.fromarray(res), stroke_inimg, stroke_refimg
#     return Image.fromarray(res), None, predict_mask_stats
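For orientation, the heart of interactive_infer_image is a text-to-mask matching step: prompt and mask embeddings are L2-normalized and each prompt is assigned the mask whose embedding is most similar. A standalone sketch of that step follows; the tensor names are illustrative, and the committed code performs this via vl_similarity with the model's learned logit_scale:

import torch

def match_prompts_to_masks(mask_emb: torch.Tensor, text_emb: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
    # mask_emb: (num_masks, dim) per-query visual embeddings; text_emb: (num_prompts, dim) prompt embeddings
    mask_emb = mask_emb / (mask_emb.norm(dim=-1, keepdim=True) + 1e-7)
    text_emb = text_emb / (text_emb.norm(dim=-1, keepdim=True) + 1e-7)
    sim = logit_scale.exp() * (mask_emb @ text_emb.t())  # (num_masks, num_prompts) similarity logits
    return sim.argmax(dim=0)  # best-matching mask index for each prompt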
inference_utils/output_processing.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
from scipy import stats
import numpy as np

import huggingface_hub


def check_mask_stats(img, mask, modality_type, target):
    # img: np.array, shape=(H, W, 3) RGB image with pixel values in [0, 255]
    # mask: np.array, shape=(H, W, 1) mask probability scaled to [0,255] with pixel values in [0, 255]
    # modality_type: str, see target_dist.json for the list of modality types
    # target: str, see target_dist.json for the list of targets

    huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename='target_dist.json', local_dir='./inference_utils')
    huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename="config.yaml", local_dir="./configs")
    target_dist = json.load(open("inference_utils/target_dist.json"))

    if modality_type not in target_dist:
        raise ValueError(f"Currently support modality types: {list(target_dist.keys())}")

    if target not in target_dist[modality_type]:
        raise ValueError(f"Currently support targets for {modality_type}: {list(target_dist[modality_type].keys())}")

    ms = mask_stats(mask, img)

    ps = [stats.ks_1samp([ms[i]], stats.beta(param[0], param[1]).cdf).pvalue for i, param in enumerate(target_dist[modality_type][target])]
    p_value = np.prod(ps)

    adj_p_value = p_value**0.24  # adjustment for the product of four tests

    return adj_p_value



def mask_stats(mask, img):
    # mask is a prediction mask with pixel values in [0, 255] for probability in [0, 1]
    # img is a RGB image with pixel values in [0, 255]
    if mask.max() <= 127:
        return [0, 0, 0, 0]
    return [mask[mask>=128].mean()/256, img[:,:,0][mask>=128].mean()/256,
            img[:,:,1][mask>=128].mean()/256, img[:,:,2][mask>=128].mean()/256]



def combine_masks(predicts):
    # predicts: a dictionary of pixel probability, {TARGET: pred_prob}
    pixel_preds = {}
    target_area = {}
    target_probs = {}
    for target in predicts:
        pred = predicts[target]
        pred_region = np.where(pred > 0.1)
        target_area[target] = 0
        target_probs[target] = 0
        for (i,j) in zip(*pred_region):
            if (i,j) not in pixel_preds:
                pixel_preds[(i,j)] = {}
            pixel_preds[(i,j)][target] = pred[i,j]
            target_area[target] += 1
            target_probs[target] += pred[i,j]
    for target in predicts:
        if target_area[target] == 0:
            continue
        target_probs[target] /= target_area[target]

    # generate combined masks
    combined_areas = {t: 0 for t in predicts}
    for index in pixel_preds:
        pred_target = sorted(pixel_preds[index].keys(), key=lambda t: pixel_preds[index][t], reverse=True)[0]
        combined_areas[pred_target] += 1

    # discard targets with small areas
    discard_targets = []
    for target in predicts:
        if combined_areas[target] < 0.6 * target_area[target]:
            discard_targets.append(target)

    # keep the most confident target
    most_confident_target = sorted(predicts.keys(), key=lambda t: target_probs[t], reverse=True)[0]

    discard_targets = [t for t in discard_targets if t != most_confident_target]

    masks = {t: np.zeros_like(predicts[t]).astype(np.uint8) for t in predicts if t not in discard_targets}
    for index in pixel_preds:
        candidates = [t for t in pixel_preds[index] if t not in discard_targets and pixel_preds[index][t] > 0.5]
        if len(candidates) == 0:
            continue
        pred_target = max(candidates, key=lambda t: pixel_preds[index][t])
        masks[pred_target][index[0], index[1]] = 1

    return masks
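A minimal usage sketch for the two entry points above (not part of the commit; the array sizes and blob positions are illustrative, and the "Pathology" / "neoplastic cells" pair is one of the entries in target_dist.json):

import numpy as np
from inference_utils.output_processing import check_mask_stats, combine_masks

# Pretend per-target probability maps in [0, 1]; a small grid keeps the pixel loops cheap.
prob_a = np.zeros((256, 256), dtype=np.float32)
prob_b = np.zeros((256, 256), dtype=np.float32)
prob_a[40:90, 40:90] = 0.9    # confident blob for the first target
prob_b[60:110, 60:110] = 0.7  # partially overlapping blob for the second target

# combine_masks resolves overlapping targets into disjoint binary masks (uint8, same shape).
masks = combine_masks({"neoplastic cells": prob_a, "inflammatory cells": prob_b})
print({target: int(m.sum()) for target, m in masks.items()})

# check_mask_stats takes the probability map rescaled to [0, 255] plus the RGB image and returns
# an adjusted p-value; it downloads target_dist.json from microsoft/BiomedParse on first use.
img = np.zeros((256, 256, 3), dtype=np.uint8)
p = check_mask_stats(img, prob_a * 255, "Pathology", "neoplastic cells")
print(f"adjusted p-value: {p:.4f}")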
inference_utils/processing_utils.py
ADDED
@@ -0,0 +1,182 @@
import numpy as np
from skimage import transform
import pydicom
from io import BytesIO
from PIL import Image
import nibabel as nib
import SimpleITK as sitk
from skimage import measure


"""
This script contains utility functions for reading and processing different imaging modalities.
"""


CT_WINDOWS = {'abdomen': [-150, 250],
              'lung': [-1000, 1000],
              'pelvis': [-55, 200],
              'liver': [-25, 230],
              'colon': [-68, 187],
              'pancreas': [-100, 200]}

def process_intensity_image(image_data, is_CT, site=None):
    # process intensity-based image. If CT, apply site-specific windowing

    # image_data: 2D numpy array of shape (H, W)

    # return: 3-channel numpy array of shape (H, W, 3) as model input

    if is_CT:
        # process image with windowing
        if site and site in CT_WINDOWS:
            window = CT_WINDOWS[site]
        else:
            raise ValueError(f'Please choose CT site from {CT_WINDOWS.keys()}')
        lower_bound, upper_bound = window
    else:
        # process image with intensity range 0.5-99.5 percentile
        lower_bound, upper_bound = np.percentile(
            image_data[image_data > 0], 0.5
        ), np.percentile(image_data[image_data > 0], 99.5)

    image_data_pre = np.clip(image_data, lower_bound, upper_bound)
    image_data_pre = (
        (image_data_pre - image_data_pre.min())
        / (image_data_pre.max() - image_data_pre.min())
        * 255.0
    )

    # pad to square with equal padding on both sides
    shape = image_data_pre.shape
    if shape[0] > shape[1]:
        pad = (shape[0]-shape[1])//2
        pad_width = ((0,0), (pad, pad))
    elif shape[0] < shape[1]:
        pad = (shape[1]-shape[0])//2
        pad_width = ((pad, pad), (0,0))
    else:
        pad_width = None

    if pad_width is not None:
        image_data_pre = np.pad(image_data_pre, pad_width, 'constant', constant_values=0)

    # resize image to 1024x1024
    image_size = 1024
    resize_image = transform.resize(image_data_pre, (image_size, image_size), order=3,
                                    mode='constant', preserve_range=True, anti_aliasing=True)

    # convert to 3-channel image
    resize_image = np.stack([resize_image]*3, axis=-1)

    return resize_image.astype(np.uint8)



def read_dicom(image_path, is_CT, site=None):
    # read dicom file and return pixel data

    # image_path: str, path to dicom file
    # is_CT: bool, whether image is CT or not
    # site: str, one of CT_WINDOWS.keys()
    # return: 3-channel numpy array of shape (1024, 1024, 3)

    ds = pydicom.dcmread(image_path)
    image_array = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept

    image_array = process_intensity_image(image_array, is_CT, site)

    return image_array


def read_nifti(image_path, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None):
    # read nifti file and return pixel data

    # image_path: str, path to nifti file
    # is_CT: bool, whether image is CT or not
    # slice_idx: int, slice index to read
    # site: str, one of CT_WINDOWS.keys()
    # HW_index: tuple, index of height and width in the image shape
    # return: 3-channel numpy array of shape (1024, 1024, 3)


    nii = nib.load(image_path)
    image_array = nii.get_fdata()

    if HW_index != (0, 1):
        image_array = np.moveaxis(image_array, HW_index, (0, 1))

    # get slice
    if channel_idx is None:
        image_array = image_array[:, :, slice_idx]
    else:
        image_array = image_array[:, :, slice_idx, channel_idx]

    image_array = process_intensity_image(image_array, is_CT, site)
    return image_array



def read_rgb(image_path):
    # read RGB image and return resized pixel data

    # image_path: str, path to RGB image
    # return: 3-channel numpy array of shape (1024, 1024, 3)

    # read image into numpy array
    image = Image.open(image_path)
    image = np.array(image)
    if len(image.shape) == 2:
        image = np.stack([image]*3, axis=-1)
    elif image.shape[2] == 4:
        image = image[:,:,:3]

    # pad to square with equal padding on both sides
    shape = image.shape
    if shape[0] > shape[1]:
        pad = (shape[0]-shape[1])//2
        pad_width = ((0,0), (pad, pad), (0,0))
    elif shape[0] < shape[1]:
        pad = (shape[1]-shape[0])//2
        pad_width = ((pad, pad), (0,0), (0,0))
    else:
        pad_width = None

    if pad_width is not None:
        image = np.pad(image, pad_width, 'constant', constant_values=0)

    # resize image to 1024x1024 for each channel
    image_size = 1024
    resize_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
    for i in range(3):
        resize_image[:,:,i] = transform.resize(image[:,:,i], (image_size, image_size), order=3,
                                               mode='constant', preserve_range=True, anti_aliasing=True)

    return resize_image



def get_instances(mask):
    # get instances from binary mask
    seg = sitk.GetImageFromArray(mask)
    filled = sitk.BinaryFillhole(seg)
    d = sitk.SignedMaurerDistanceMap(filled, insideIsPositive=False, squaredDistance=False, useImageSpacing=False)

    ws = sitk.MorphologicalWatershed( d, markWatershedLine=False, level=1)
    ws = sitk.Mask( ws, sitk.Cast(seg, ws.GetPixelID()))
    ins_mask = sitk.GetArrayFromImage(ws)

    # filter out instances with small area outliers
    props = measure.regionprops_table(ins_mask, properties=('label', 'area'))
    mean_area = np.mean(props['area'])
    std_area = np.std(props['area'])

    threshold = mean_area - 2*std_area - 1
    ins_mask_filtered = ins_mask.copy()
    for i, area in zip(props['label'], props['area']):
        if area < threshold:
            ins_mask_filtered[ins_mask == i] = 0

    return ins_mask_filtered
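For reference, a sketch of how these readers are intended to be called (not part of the commit; the DICOM and NIfTI paths are placeholders, while the PNG is the example bundled with this Space):

from inference_utils.processing_utils import read_rgb, read_dicom, read_nifti

# RGB image (pathology, endoscopy, dermoscopy, ...): padded to square, resized to 1024x1024.
rgb = read_rgb("examples/Part_1_516_pathology_breast.png")           # (1024, 1024, 3) uint8

# CT slice from DICOM: site must be one of CT_WINDOWS so the right window is applied.
# ct = read_dicom("path/to/ct_slice.dcm", is_CT=True, site="abdomen")

# One slice of a NIfTI volume (e.g. MRI): 0.5-99.5 percentile intensity normalization.
# mri = read_nifti("path/to/volume.nii.gz", is_CT=False, slice_idx=60)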
inference_utils/target_dist.json
ADDED
@@ -0,0 +1 @@
{"CT-Abdomen": {"postcava": [[244.8001455798728, 5.314270814858824], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391]], "aorta": [[570.5260544851909, 8.97527503179567], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238]], "right kidney": [[831.8568013426873, 14.991866448573818], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316]], "kidney": [[824.7288483151449, 17.740666994112335], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919]], "left kidney": [[765.9269280548916, 14.314482540419498], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515]], "duodenum": [[121.5002253116006, 5.0616837393558045], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173]], "pancreas": [[182.85416969377923, 6.9039775525067135], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656]], "liver (non abdomen window)": [[481.5690096331249, 8.413924027868077], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198]], "liver": [[497.88613290346797, 8.79208581405346], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742]], "spleen": [[496.77984794364835, 8.498216025126785], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987]], "stomach": [[137.7555592980079, 3.928159238756134], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921]], "gallbladder": [[109.56988864543307, 3.4765854683723596], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384]], "left adrenal gland": [[121.60075395406241, 4.266683492995461], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753]], "adrenal gland": [[182.4265613513338, 7.813186080282246], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345]], "right adrenal gland": [[158.21570288963346, 5.736947411814261], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653]], "bladder": [[172.667607742299, 4.6885066612866835], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909]], "esophagus": [[253.86092392814248, 6.886078359154348], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301]]}, "CT-Chest": {"nodule": [[115.14726334918862, 3.0043952160348844], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393]], "COVID-19 infection": [[226.93782607812352, 10.662200522447263], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407]], "tumor": [[81.39154648592063, 3.0363381821985254], [9.799683628807484, 19.248706134279548], 
[9.799683628807484, 19.248706134279548], [9.799683628807484, 19.248706134279548]]}, "MRI-Abdomen": {"aorta": [[840.9822169946456, 13.699556855062456], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954]], "postcava": [[151.3891903352374, 4.700455115571472], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995]], "right kidney": [[613.4017011464975, 11.282616103318485], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867]], "duodenum": [[88.51851857758399, 5.251374959142798], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745]], "kidney": [[831.5762248415444, 18.739059302777875], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527]], "left kidney": [[255.4744196400276, 5.573793361388763], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708]], "liver": [[491.1931789168259, 9.294627086787225], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463]], "pancreas": [[136.2304629992425, 5.676744286342953], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567]], "gallbladder": [[75.18767252055355, 2.8711737605829892], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496]], "stomach": [[89.16380420023327, 4.461224829090838], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376]], "spleen": [[413.92566589639046, 7.99961594912814], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216]], "left adrenal gland": [[86.44109991236728, 4.826813402237061], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408]], "adrenal gland": [[303.9642820935704, 16.729857009916806], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544]], "right adrenal gland": [[172.36803145644578, 8.050377438528958], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772]], "esophagus": [[193.1348898340059, 7.6397334220243325], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354]]}, "MRI-Cardiac": {"left heart ventricle": [[964.9072936969454, 17.21177762137991], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713]], "myocardium": [[448.3393673888417, 17.591805257426998], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415]], "right heart ventricle": [[359.88937669636215, 9.392153523781843], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979]]}, "MRI-FLAIR-Brain": {"edema": [[69.4159007224176, 5.568921766085619], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592]], "tumor 
core": [[154.26935124167449, 8.089254912853598], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397]], "whole tumor": [[485.48717118600956, 16.01178236475156], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145]]}, "MRI-T1-Gd-Brain": {"enhancing tumor": [[175.6437881777937, 7.539344668413025], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689]], "non-enhancing tumor": [[37.6625733247702, 3.8454536110058246], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484]], "tumor core": [[180.88223552813486, 6.610443841067055], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197]]}, "Pathology": {"connective tissue cells": [[46.71165884847293, 4.997126203483956], [9.942495884846476, 15.700775443760845], [4.328453739888501, 18.42621798468577], [9.798096322131162, 11.920352021312304]], "inflammatory cells": [[39.600337990197595, 3.1848025413959706], [6.287418328538852, 20.538379638162322], [2.9521703595392146, 25.264465092284006], [6.559595490616054, 12.004686961917436]], "neoplastic cells": [[82.29374052289526, 8.22429924322936], [9.592296798563375, 14.818916788142138], [4.948629785308088, 19.78516221506478], [10.729094314024243, 12.934345198477494]], "epithelial cells": [[91.75183574899573, 9.577544361042948], [13.469843493323452, 27.305962287612964], [4.696928248406198, 25.254143364646463], [11.077634907582583, 13.487595094752443]]}, "X-Ray-Chest": {"left lung": [[529.1669758355144, 7.465035502868491], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364]], "lung": [[465.7809501354513, 7.147122106450173], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102]], "right lung": [[567.6127039725319, 7.532428563004494], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746]]}, "Ultrasound-Cardiac": {"left heart atrium": [[1188.687550702627, 24.234766943758856], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291]], "left heart ventricle": [[2787.334986695437, 58.297232816307506], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377]]}, "Endoscopy": {"neoplastic polyp": [[392.89875472390315, 5.4678888279040745], [7.477729277754545, 1.6522601344780465], [7.2704247484339035, 6.347521355120636], [4.3902399436060335, 6.543658310376327]], "polyp": [[163.7838288028474, 3.4851615302599117], [7.03659746479883, 1.9088902542177986], [6.992807172875011, 6.756628353721484], [5.185761648208865, 8.977427344868255]], "non-neoplastic polyp": [[214.9199548332033, 4.360826895414348], [7.303363948417486, 1.9789835935004905], [10.54652900087687, 9.009706115553772], [6.917879576439251, 10.404634951284532]]}, "Fundus": {"optic cup": [[1482.9561484784422, 35.78105120937013], [52.1031548324398, 1.5080077510381715], [10.023538467761934, 3.1641925551155046], [3.394564722036805, 2.4391933423559626]], "optic disc": [[626.9141229495486, 20.95002931507066], [18.278454005466408, 1.8261365514325893], [16.42282430959315, 11.171338052048034], [4.8937792939550135, 6.987302868644637]]}, 
"Dermoscopy": {"lesion": [[134.43456931870887, 4.743684855379663], [5.18053578956456, 2.3527492367343634], [3.809383004477107, 6.368793378843402], [2.3888068456218847, 6.655396307215968]], "melanoma": [[454.17848530764076, 9.6466178116726], [4.022144360826467, 7.870140640677671], [4.87109613458874, 18.93721534855073], [3.107895746664011, 13.604075970992069]]}, "OCT": {"edema": [[260.11475018501574, 7.379315940573871], [4.162158474003, 17.437425953761988], [12.65808078622105, 81.37165793634547], [1.763378481483125, 4.427309203795247]]}}
main.py
ADDED
@@ -0,0 +1,106 @@
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image


def overlay_masks(image, masks, colors):
    overlay = image.copy()
    overlay = np.array(overlay, dtype=np.uint8)
    for mask, color in zip(masks, colors):
        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
            np.uint8
        )
    return Image.fromarray(overlay)


def generate_colors(n):
    cmap = plt.get_cmap("tab10")
    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
    return colors


def init_model():
    # Download model
    model_file = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Initialize model
    conf_files = "configs/biomedparse_inference.yaml"
    opt = load_opt_from_config_files([conf_files])
    opt = init_distributed(opt)

    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )

    return model


def predict(image, prompts):
    if not prompts:
        return None

    # Convert string input to list
    prompts = [p.strip() for p in prompts.split(",")]

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get predictions
    pred_mask = interactive_infer_image(model, image, prompts)

    # Generate visualization
    colors = generate_colors(len(prompts))
    pred_overlay = overlay_masks(
        image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
    )

    return pred_overlay


def run():
    global model
    model = init_model()

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(
                label="Prompts",
                placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
            ),
        ],
        outputs=gr.Image(type="pil", label="Prediction"),
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=[
            [
                "examples/Part_1_516_pathology_breast.png",
                "neoplastic cells, inflammatory cells",
            ]
        ],
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    print(f"HF_TOKEN={os.getenv('HF_TOKEN')}")
    run()
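The same pipeline can also be exercised outside the Gradio UI; a sketch, assuming a CUDA device and access to the gated microsoft/BiomedParse checkpoint (HF_TOKEN set), reusing the bundled example image:

from PIL import Image
import main

main.model = main.init_model()   # downloads biomedparse_v1.pt and builds the model on GPU
image = Image.open("examples/Part_1_516_pathology_breast.png")
overlay = main.predict(image, "neoplastic cells, inflammatory cells")
overlay.save("prediction_overlay.png")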
modeling/BaseModel.py
ADDED
@@ -0,0 +1,45 @@
import os
import logging

import torch
import torch.nn as nn

from utilities.model import align_and_update_state_dicts

from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files

import huggingface_hub

logger = logging.getLogger(__name__)


class BaseModel(nn.Module):
    def __init__(self, opt, module: nn.Module):
        super(BaseModel, self).__init__()
        self.opt = opt
        self.model = module

    def forward(self, *inputs, **kwargs):
        outputs = self.model(*inputs, **kwargs)
        return outputs

    def save_pretrained(self, save_dir):
        torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))

    def from_pretrained(self, pretrained, filename: str = "biomedparse_v1.pt",
                        local_dir: str = "./pretrained", config_dir: str = "./configs"):
        if pretrained.startswith("hf_hub:"):
            hub_name = pretrained.split(":")[1]
            huggingface_hub.hf_hub_download(hub_name, filename=filename,
                                            local_dir=local_dir)
            huggingface_hub.hf_hub_download(hub_name, filename="config.yaml",
                                            local_dir=config_dir)
            load_dir = os.path.join(local_dir, filename)
        else:
            load_dir = pretrained

        state_dict = torch.load(load_dir, map_location=self.opt['device'])
        state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
        self.model.load_state_dict(state_dict, strict=False)
        return self
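from_pretrained also accepts an "hf_hub:" reference, in which case the checkpoint and config.yaml are pulled from the Hub; a minimal sketch mirroring the flow in main.py (assuming access to the gated repo and a CUDA device):

from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.arguments import load_opt_from_config_files
from utilities.distributed import init_distributed

opt = init_distributed(load_opt_from_config_files(["configs/biomedparse_inference.yaml"]))
model = BaseModel(opt, build_model(opt)).from_pretrained("hf_hub:microsoft/BiomedParse").eval().cuda()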
modeling/__init__.py
ADDED
@@ -0,0 +1 @@
from .architectures import build_model
modeling/architectures/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .xdecoder_model import *
from .seem_model_v0 import *
from .seem_model_v1 import *
from .seem_model_demo import *
from .build import build_model
modeling/architectures/build.py
ADDED
@@ -0,0 +1,22 @@
_model_entrypoints = {}


def build_model(config, **kwargs):
    model_name = config['MODEL']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)

def register_model(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
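The registry keys each builder by the name of the module that defines it, so MODEL.NAME in the config must match the architecture's filename; a hypothetical new architecture would register itself like this (sketch only; my_seg_model is not part of this commit):

# modeling/architectures/my_seg_model.py
from .build import register_model

@register_model
def get_my_seg_model(cfg, **kwargs):
    # construct and return the nn.Module for this architecture
    ...

# selected with MODEL.NAME == "my_seg_model" (the module name, not the function name)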
modeling/architectures/seem_model_demo.py
ADDED
@@ -0,0 +1,923 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# SEEM -- Segment Everything Everywhere All at Once
|
3 |
+
# Licensed under The Apache License 2.0 [see LICENSE for details]
|
4 |
+
# Written by Xueyan Zou ([email protected])
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
import random
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
from kornia.contrib import distance_transform
|
15 |
+
|
16 |
+
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
|
17 |
+
from detectron2.utils.memory import retry_if_cuda_oom
|
18 |
+
from detectron2.data import MetadataCatalog
|
19 |
+
|
20 |
+
from .build import register_model
|
21 |
+
|
22 |
+
from ..utils import configurable, get_class_names, get_iou
|
23 |
+
from ..vision.backbone import build_backbone, Backbone
|
24 |
+
from ..body import build_xdecoder_head
|
25 |
+
from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
|
26 |
+
from ..language import build_language_encoder
|
27 |
+
from ..language.loss import vl_similarity
|
28 |
+
from utilities.prompt_engineering import prompt_engineering
|
29 |
+
from utilities.constants import COCO_PANOPTIC_CLASSES
|
30 |
+
|
31 |
+
|
32 |
+
class GeneralizedSEEM(nn.Module):
|
33 |
+
|
34 |
+
@configurable
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
*,
|
38 |
+
backbone: Backbone,
|
39 |
+
sem_seg_head: nn.Module,
|
40 |
+
criterion: nn.Module,
|
41 |
+
losses: dict,
|
42 |
+
num_queries: int,
|
43 |
+
object_mask_threshold: float,
|
44 |
+
overlap_threshold: float,
|
45 |
+
metadata,
|
46 |
+
task_switch: dict,
|
47 |
+
phrase_prob: float,
|
48 |
+
size_divisibility: int,
|
49 |
+
sem_seg_postprocess_before_inference: bool,
|
50 |
+
pixel_mean: Tuple[float],
|
51 |
+
pixel_std: Tuple[float],
|
52 |
+
# inference
|
53 |
+
semantic_on: bool,
|
54 |
+
panoptic_on: bool,
|
55 |
+
instance_on: bool,
|
56 |
+
test_topk_per_image: int,
|
57 |
+
train_dataset_name: str,
|
58 |
+
interactive_mode: str,
|
59 |
+
interactive_iter: str,
|
60 |
+
dilation_kernel: torch.Tensor,
|
61 |
+
):
|
62 |
+
super().__init__()
|
63 |
+
self.backbone = backbone
|
64 |
+
self.sem_seg_head = sem_seg_head
|
65 |
+
self.criterion = criterion
|
66 |
+
self.losses = losses
|
67 |
+
self.num_queries = num_queries
|
68 |
+
self.overlap_threshold = overlap_threshold
|
69 |
+
self.object_mask_threshold = object_mask_threshold
|
70 |
+
self.metadata = metadata
|
71 |
+
if size_divisibility < 0:
|
72 |
+
# use backbone size_divisibility if not set
|
73 |
+
size_divisibility = self.backbone.size_divisibility
|
74 |
+
self.size_divisibility = size_divisibility
|
75 |
+
self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
|
76 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
77 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
78 |
+
|
79 |
+
# additional args
|
80 |
+
self.semantic_on = semantic_on
|
81 |
+
self.instance_on = instance_on
|
82 |
+
self.panoptic_on = panoptic_on
|
83 |
+
|
84 |
+
# caption argument
|
85 |
+
self.task_switch = task_switch
|
86 |
+
self.phrase_prob = phrase_prob
|
87 |
+
|
88 |
+
self.test_topk_per_image = test_topk_per_image
|
89 |
+
self.train_class_names = None
|
90 |
+
self.interactive_mode = interactive_mode
|
91 |
+
self.interactive_iter = interactive_iter
|
92 |
+
|
93 |
+
if not self.semantic_on:
|
94 |
+
assert self.sem_seg_postprocess_before_inference
|
95 |
+
|
96 |
+
self.register_buffer("dilation_kernel", dilation_kernel)
|
97 |
+
|
98 |
+
@classmethod
|
99 |
+
def from_config(cls, cfg):
|
100 |
+
enc_cfg = cfg['MODEL']['ENCODER']
|
101 |
+
dec_cfg = cfg['MODEL']['DECODER']
|
102 |
+
|
103 |
+
openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
|
104 |
+
'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
|
105 |
+
|
106 |
+
task_switch = {'bbox': dec_cfg.get('DETECTION', False),
|
107 |
+
'mask': dec_cfg.get('MASK', True),
|
108 |
+
'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
|
109 |
+
'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
|
110 |
+
'openimage': openimage_switch,
|
111 |
+
'visual': dec_cfg['VISUAL'].get('ENABLED', False),
|
112 |
+
'audio': dec_cfg['AUDIO'].get('ENABLED', False)}
|
113 |
+
|
114 |
+
# build model
|
115 |
+
extra = {'task_switch': task_switch}
|
116 |
+
backbone = build_backbone(cfg)
|
117 |
+
lang_encoder = build_language_encoder(cfg)
|
118 |
+
sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
|
119 |
+
|
120 |
+
# Training Settings.
|
121 |
+
loss_weights = {}
|
122 |
+
matcher = None
|
123 |
+
losses = {}
|
124 |
+
weight_dict = {}
|
125 |
+
grd_weight = {}
|
126 |
+
top_x_layers = {}
|
127 |
+
criterion = None
|
128 |
+
train_dataset_name = None
|
129 |
+
phrase_prob = None
|
130 |
+
# Loss parameters:
|
131 |
+
deep_supervision = None
|
132 |
+
no_object_weight = None
|
133 |
+
|
134 |
+
interactive_mode = 'best'
|
135 |
+
interactive_iter = 20
|
136 |
+
dilation = 3
|
137 |
+
dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
|
138 |
+
|
139 |
+
return {
|
140 |
+
"backbone": backbone,
|
141 |
+
"sem_seg_head": sem_seg_head,
|
142 |
+
"criterion": criterion,
|
143 |
+
"losses": losses,
|
144 |
+
"num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
|
145 |
+
"object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
|
146 |
+
"overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
|
147 |
+
"metadata": None,
|
148 |
+
"size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
|
149 |
+
"sem_seg_postprocess_before_inference": (
|
150 |
+
dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
|
151 |
+
or dec_cfg['TEST']['PANOPTIC_ON']
|
152 |
+
or dec_cfg['TEST']['INSTANCE_ON']
|
153 |
+
),
|
154 |
+
"pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
|
155 |
+
"pixel_std": cfg['INPUT']['PIXEL_STD'],
|
156 |
+
"task_switch": task_switch,
|
157 |
+
"phrase_prob": phrase_prob,
|
158 |
+
# inference
|
159 |
+
"semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
|
160 |
+
"instance_on": dec_cfg['TEST']['INSTANCE_ON'],
|
161 |
+
"panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
|
162 |
+
"test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
|
163 |
+
"train_dataset_name": train_dataset_name,
|
164 |
+
"interactive_mode": interactive_mode,
|
165 |
+
"interactive_iter": interactive_iter,
|
166 |
+
"dilation_kernel": dilation_kernel,
|
167 |
+
}
|
168 |
+
|
169 |
+
@property
|
170 |
+
def device(self):
|
171 |
+
return self.pixel_mean.device
|
172 |
+
|
173 |
+
def forward(self, batched_inputs, mode='default'):
|
174 |
+
if self.training:
|
175 |
+
losses = {}
|
176 |
+
if self.task_switch['mask']:
|
177 |
+
losses_seg = self.forward_seg(batched_inputs)
|
178 |
+
losses.update(losses_seg)
|
179 |
+
if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
|
180 |
+
losses_openimage = self.forward_openimage(batched_inputs['openimage'])
|
181 |
+
losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
|
182 |
+
losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
|
183 |
+
losses.update(losses_openimage)
|
184 |
+
for k in list(losses.keys()):
|
185 |
+
if k in self.criterion.weight_dict:
|
186 |
+
losses[k] *= self.criterion.weight_dict[k]
|
187 |
+
else: # remove this loss if not specified in `weight_dict`
|
188 |
+
losses.pop(k)
|
189 |
+
return losses
|
190 |
+
else:
|
191 |
+
if mode == 'interactive':
|
192 |
+
return self.evaluate_interactive(batched_inputs)
|
193 |
+
elif mode == 'grounding_spatial':
|
194 |
+
return self.evaluate_grounding_sptial(batched_inputs, mode)
|
195 |
+
elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
|
196 |
+
return self.evaluate_grounding(batched_inputs, mode)
|
197 |
+
else:
|
198 |
+
return self.evaluate(batched_inputs)
|
199 |
+
|
200 |
+
|
201 |
+
def forward_seg(self, batched_inputs):
|
202 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
203 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
204 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
205 |
+
|
206 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
|
207 |
+
|
208 |
+
extra = {}
|
209 |
+
# mask classification target
|
210 |
+
if "instances" in batched_inputs[0]:
|
211 |
+
# input bounding box is checked to be correct.
|
212 |
+
targets = self.prepare_targets(batched_inputs, images)
|
213 |
+
|
214 |
+
if self.task_switch['grounding']:
|
215 |
+
grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
|
216 |
+
grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
|
217 |
+
non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
|
218 |
+
grounding_tokens[non_zero_query_mask] = 0
|
219 |
+
|
220 |
+
extra['grounding_tokens'] = grounding_tokens
|
221 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
222 |
+
|
223 |
+
if self.task_switch['spatial']:
|
224 |
+
pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
|
225 |
+
neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
|
226 |
+
fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
|
227 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
|
228 |
+
|
229 |
+
features = self.backbone(images.tensor)
|
230 |
+
mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
231 |
+
|
232 |
+
# forward spatial only without gradient
|
233 |
+
if self.task_switch['spatial']:
|
234 |
+
with torch.no_grad():
|
235 |
+
# generate random integeter between [0,3]
|
236 |
+
rand_iter_num = random.randint(0, 2)
|
237 |
+
for i in range(rand_iter_num):
|
238 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
|
239 |
+
extra.update(outputs)
|
240 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
|
241 |
+
|
242 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
|
243 |
+
extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
|
244 |
+
'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
|
245 |
+
'false_positive_mask': extra['false_positive_mask']}
|
246 |
+
# bipartite matching-based loss
|
247 |
+
self.criterion.losses = self.losses['seg'] # seg criterion losses
|
248 |
+
losses = self.criterion(outputs, targets, extra)
|
249 |
+
|
250 |
+
del outputs
|
251 |
+
return losses
|
252 |
+
|
253 |
+
def evaluate_demo(self, batched_inputs):
|
254 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
255 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
256 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
257 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
258 |
+
img_bs = images.tensor.shape[0]
|
259 |
+
|
260 |
+
targets = targets_grounding = queries_grounding = None
|
261 |
+
features = self.backbone(images.tensor)
|
262 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
263 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
264 |
+
|
265 |
+
extra = {}
|
266 |
+
if 'stroke' in batched_inputs[0]:
|
267 |
+
pos_masks = (batched_inputs[0]['stroke'].to(self.device)).unbind(0)
|
268 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
269 |
+
neg_masks = (batched_inputs[0]['stroke'].to(self.device) & False).unbind(0)
|
270 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
271 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
272 |
+
|
273 |
+
if 'visual' in batched_inputs[0]:
|
274 |
+
extra.update(batched_inputs[0]['visual'])
|
275 |
+
|
276 |
+
if 'text' in batched_inputs[0]:
|
277 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['text'], name='grounding', token=False, norm=False)
|
278 |
+
token_emb = gtext['token_emb']
|
279 |
+
tokens = gtext['tokens']
|
280 |
+
query_emb = token_emb[tokens['attention_mask'].bool()]
|
281 |
+
non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
|
282 |
+
extra['grounding_tokens'] = query_emb[:,None]
|
283 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
284 |
+
extra['grounding_class'] = gtext['class_emb']
|
285 |
+
|
286 |
+
if 'audio' in batched_inputs[0]:
|
287 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['audio'], name='grounding', token=False, norm=False)
|
288 |
+
token_emb = gtext['token_emb']
|
289 |
+
tokens = gtext['tokens']
|
290 |
+
query_emb = token_emb[tokens['attention_mask'].bool()]
|
291 |
+
non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
|
292 |
+
extra['audio_tokens'] = query_emb[:,None]
|
293 |
+
extra['audio_nonzero_mask'] = non_zero_query_mask.t()
|
294 |
+
extra['audio_class'] = gtext['class_emb']
|
295 |
+
|
296 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='demo')
|
297 |
+
return outputs, images.tensor.shape, extra
|
298 |
+
|
299 |
+
assert self.task_switch['spatial']
|
300 |
+
assert 'spatial_query' in batched_inputs[0]
|
301 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
302 |
+
|
303 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
304 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
305 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
306 |
+
img_bs = images.tensor.shape[0]
|
307 |
+
|
308 |
+
targets = targets_grounding = queries_grounding = None
|
309 |
+
extra = {}
|
310 |
+
|
311 |
+
features = self.backbone(images.tensor)
|
312 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
313 |
+
|
314 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
315 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
316 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
317 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
318 |
+
|
319 |
+
all_batch_shape_iou = []
|
320 |
+
pred_smask_pointer = None
|
321 |
+
prev_smask_pointer = None
|
322 |
+
pred_smask_all = None
|
323 |
+
|
324 |
+
query_index = self.sem_seg_head.predictor.query_index
|
325 |
+
assert self.interactive_mode == 'best'
|
326 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
327 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
328 |
+
|
329 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
330 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
331 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
332 |
+
|
333 |
+
for i in range(self.interactive_iter):
|
334 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
|
335 |
+
extra.update(outputs)
|
336 |
+
pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
|
337 |
+
|
338 |
+
s = image_sizes[0]
|
339 |
+
b = batched_inputs[0]
|
340 |
+
pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
|
341 |
+
gt_smask = b['gt_masks_orisize']
|
342 |
+
all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
|
343 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
|
344 |
+
|
345 |
+
all_batch_shape_iou = torch.stack(all_batch_shape_iou)
|
346 |
+
processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
|
347 |
+
return processed_results
|
348 |
+
|
349 |
+
def evaluate(self, batched_inputs):
|
350 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
351 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
352 |
+
|
353 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
354 |
+
img_bs = images.tensor.shape[0]
|
355 |
+
|
356 |
+
targets = targets_grounding = queries_grounding = None
|
357 |
+
features = self.backbone(images.tensor)
|
358 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
359 |
+
|
360 |
+
mask_cls_results = outputs["pred_logits"]
|
361 |
+
mask_pred_results = outputs["pred_masks"]
|
362 |
+
box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
|
363 |
+
|
364 |
+
# upsample masks
|
365 |
+
mask_pred_results = F.interpolate(
|
366 |
+
mask_pred_results,
|
367 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
368 |
+
mode="bilinear",
|
369 |
+
align_corners=False,
|
370 |
+
)
|
371 |
+
|
372 |
+
input_size = mask_pred_results.shape[-2:]
|
373 |
+
del outputs
|
374 |
+
|
375 |
+
processed_results = []
|
376 |
+
for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
|
377 |
+
mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
|
378 |
+
):
|
379 |
+
height = input_per_image.get("height", image_size[0])
|
380 |
+
width = input_per_image.get("width", image_size[1])
|
381 |
+
processed_results.append({})
|
382 |
+
|
383 |
+
if self.sem_seg_postprocess_before_inference:
|
384 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
385 |
+
mask_pred_result, image_size, height, width
|
386 |
+
)
|
387 |
+
mask_cls_result = mask_cls_result.to(mask_pred_result)
|
388 |
+
|
389 |
+
# semantic segmentation inference
|
390 |
+
if self.semantic_on:
|
391 |
+
r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
|
392 |
+
if not self.sem_seg_postprocess_before_inference:
|
393 |
+
r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
|
394 |
+
processed_results[-1]["sem_seg"] = r
|
395 |
+
|
396 |
+
# panoptic segmentation inference
|
397 |
+
if self.panoptic_on:
|
398 |
+
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
|
399 |
+
processed_results[-1]["panoptic_seg"] = panoptic_r
|
400 |
+
|
401 |
+
# instance segmentation inference
|
402 |
+
if self.instance_on:
|
403 |
+
if self.task_switch['bbox']:
|
404 |
+
box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
|
405 |
+
instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
|
406 |
+
processed_results[-1]["instances"] = instance_r
|
407 |
+
|
408 |
+
return processed_results
|
409 |
+
|
410 |
+
def evaluate_interactive(self, batched_inputs):
|
411 |
+
assert self.task_switch['spatial']
|
412 |
+
assert 'spatial_query' in batched_inputs[0]
|
413 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
414 |
+
|
415 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
416 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
417 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
418 |
+
img_bs = images.tensor.shape[0]
|
419 |
+
|
420 |
+
targets = targets_grounding = queries_grounding = None
|
421 |
+
extra = {}
|
422 |
+
|
423 |
+
features = self.backbone(images.tensor)
|
424 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
425 |
+
|
426 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
427 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
428 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
429 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
430 |
+
|
431 |
+
all_batch_shape_iou = []
|
432 |
+
pred_smask_pointer = None
|
433 |
+
prev_smask_pointer = None
|
434 |
+
pred_smask_all = None
|
435 |
+
|
436 |
+
query_index = self.sem_seg_head.predictor.query_index
|
437 |
+
assert self.interactive_mode == 'best'
|
438 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
439 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
440 |
+
|
441 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
442 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
443 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
444 |
+
|
445 |
+
for i in range(self.interactive_iter):
|
446 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
|
447 |
+
extra.update(outputs)
|
448 |
+
pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
|
449 |
+
|
450 |
+
s = image_sizes[0]
|
451 |
+
b = batched_inputs[0]
|
452 |
+
pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
|
453 |
+
gt_smask = b['gt_masks_orisize']
|
454 |
+
all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
|
455 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
|
456 |
+
|
457 |
+
all_batch_shape_iou = torch.stack(all_batch_shape_iou)
|
458 |
+
processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
|
459 |
+
return processed_results
|
460 |
+
|
461 |
+
def evaluate_referring_image(self, batched_inputs, extra={}):
|
462 |
+
assert self.task_switch['spatial']
|
463 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
464 |
+
assert self.interactive_mode == 'best'
|
465 |
+
|
466 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
467 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
468 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
469 |
+
img_bs = images.tensor.shape[0]
|
470 |
+
|
471 |
+
targets = targets_grounding = queries_grounding = None
|
472 |
+
features = self.backbone(images.tensor)
|
473 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
474 |
+
|
475 |
+
if 'spatial_query' in batched_inputs[0]:
|
476 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
477 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
478 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
479 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
480 |
+
|
481 |
+
query_index = self.sem_seg_head.predictor.query_index
|
482 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
483 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
484 |
+
|
485 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
486 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
487 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
488 |
+
|
489 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
|
490 |
+
return outputs, images.tensor.shape
|
491 |
+
|
492 |
+
    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']

        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]

        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]

        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)

            extra['grounding_tokens'] = query_emb[:,None]
            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_gmasks'][idx]
            v_emb = outputs['pred_gtexts'][idx]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

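The grounding path scores every mask query against the sentence-level text embedding and keeps the best match per prompt via `vl_similarity`. A minimal sketch of that matching step, assuming the learned logit scale acts as a plain multiplicative temperature on cosine similarity (the actual function lives in `modeling/language/loss.py`):

import torch

def match_text_to_queries(v_emb: torch.Tensor, t_emb: torch.Tensor, temperature: float = 100.0) -> torch.Tensor:
    # v_emb: (Q, D) L2-normalized mask-query embeddings; t_emb: (T, D) L2-normalized text embeddings.
    logits = temperature * v_emb @ t_emb.t()  # (Q, T) scaled cosine similarities
    return logits.max(0)[1]                   # index of the best-matching query for each text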
    def evaluate_grounding_sptial(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        dilation = 3
        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
        pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)

        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)

        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_masks = []
            for idx2, anno_text in enumerate(grd_texts):
                extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})

                gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
                token_emb = gtext['token_emb']
                tokens = gtext['tokens']

                grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
                non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
                extra['grounding_tokens'] = grd_emb[:,None]
                extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

                assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
                features = self.backbone(images.tensor)
                outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

                pred_gmasks = outputs['pred_gmasks'][idx]
                v_emb = outputs['pred_gtexts'][idx]
                t_emb = gtext['class_emb']

                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

                matched_id = out_prob.max(0)[1]
                grd_masks += [pred_gmasks[matched_id,:,:]]
            mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_texts = [x[0] for x in grd_texts]

        #     gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
        #     token_emb = gtext['token_emb']
        #     tokens = gtext['tokens']
        #     query_emb = token_emb[tokens['attention_mask'].bool()]
        #     non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)

        #     extra['grounding_tokens'] = query_emb[:,None]
        #     extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

        #     features = self.backbone(images.tensor)
        #     outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #     pred_gmasks = outputs['pred_gmasks'][idx]
        #     v_emb = outputs['pred_gtexts'][idx]
        #     t_emb = gtext['class_emb']

        #     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #     temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #     matched_id = out_prob.max(0)[1]
        #     mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

        return processed_results

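The spatial variant dilates each point/scribble mask before using it as a positive query, via a convolution with the all-ones `dilation_kernel` buffer. A self-contained sketch of that trick (kernel size 3, as in the code above):

import torch
import torch.nn.functional as F

def dilate_masks(masks: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    # masks: (N, 1, H, W) boolean. Box dilation = ones-kernel convolution + thresholding.
    kernel = torch.ones((1, 1, kernel_size, kernel_size), device=masks.device)
    return F.conv2d(masks.float(), kernel, padding=kernel_size // 2) > 0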
    def prepare_targets(self, batched_inputs, images):
        h_pad, w_pad = images.tensor.shape[-2:]
        new_targets = []
        for idx, batch_per_image in enumerate(batched_inputs):
            targets_per_image = batch_per_image['instances'].to(self.device)
            # pad gt
            gt_masks = targets_per_image.gt_masks.tensor
            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks

            gt_boxes = targets_per_image.gt_boxes.tensor
            ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
            gt_boxes = gt_boxes / ratio
            xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
            gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)

            target_dict = {
                "labels": targets_per_image.gt_classes,
                "is_things": targets_per_image.is_things,
                "masks": padded_masks,
                "boxes": gt_boxes,
            }

            if self.task_switch['spatial']:
                # prepare targets for spatial query
                target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']

            if self.task_switch['grounding']:
                grd_masks = batch_per_image['groundings']['masks']
                grd_texts = batch_per_image['groundings']['texts']
                grd_hash = batch_per_image['groundings']['hash']
                grd_task = batch_per_image['groundings']['mode']

                if len(grd_masks) == 0:
                    padded_masks = None
                else:
                    padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
                    padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks

                gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
                token_emb = gtext['token_emb']
                tokens = gtext['tokens']

                unique_hash_id = np.unique(grd_hash, return_index=True)[1]
                selected_mask = np.zeros(len(grd_hash)).astype(bool)
                selected_mask[unique_hash_id] = True

                selected_token_emb = token_emb[selected_mask]
                selected_attn_mask = tokens['attention_mask'][selected_mask]
                query_emb = selected_token_emb[selected_attn_mask.bool()]

                class_idx = tokens['attention_mask'].sum(dim=-1) - 1
                class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
                class_emb = token_emb[class_idx]

                target_dict['grounding_masks'] = padded_masks
                target_dict['grounding_query_embs'] = query_emb
                target_dict['grounding_class_embs'] = class_emb
                target_dict['grounding_hash'] = grd_hash
                target_dict['grounding_task'] = grd_task

            new_targets.append(target_dict)
        return new_targets

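`prepare_targets` converts absolute XYXY ground-truth boxes into the normalized (cx, cy, w, h) format expected by the matcher. A small worked sketch of that conversion:

import torch

def xyxy_abs_to_cxcywh_rel(boxes: torch.Tensor, h_pad: int, w_pad: int) -> torch.Tensor:
    # boxes: (N, 4) absolute XYXY on the padded image; returns normalized (cx, cy, w, h).
    boxes = boxes / torch.tensor([w_pad, h_pad, w_pad, h_pad], dtype=boxes.dtype)
    x0, y0, x1, y1 = boxes.unbind(-1)
    return torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=-1)

# e.g. the box (100, 50, 300, 150) on a 400x200 padded image becomes (0.5, 0.5, 0.5, 0.5).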
    def prepare_next_spaital_mask(self, outputs, batched_inputs):
        gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
        if self.training:
            gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
        else:
            gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)

        pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
        prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])

        fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)  # fn: False Negative, gt:1, pred:0, prev:0
        fp = (~gt_masks & pred_masks) & (~prev_masks)  # fp: False Positive, gt:0, pred:1, prev:0

        # compute iou between gt and pred
        iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
        fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
        fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))

        is_postive = fn_sum > fp_sum
        # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
        select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])

        # conv implementation
        n,_,h,w = select_mask.shape
        mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
        max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n,-1)
        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n,1,h,w)).float()
        dilation = 3
        next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0

        # determine whether next mask is zero
        keep = (iou < 0.925)
        next_mask = next_mask & keep.view(-1,1,1,1)

        pos_mask = []
        neg_mask = []
        for idx, ip in enumerate(is_postive):
            if ip:
                pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
                neg_mask += [outputs['spatial_query_neg_mask'][idx]]
            else:
                pos_mask += [outputs['spatial_query_pos_mask'][idx]]
                neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]

        if 'false_positive_mask' in outputs:
            fp = outputs['false_positive_mask'] | fp
        return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}

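`prepare_next_spaital_mask` simulates the next user click: it picks the larger of the false-negative / false-positive regions, places a point at the pixel deepest inside that region (the maximum of a distance transform), dilates it, and adds it to the positive or negative query masks. A sketch of the click-placement step, using SciPy's Euclidean distance transform as an illustrative stand-in for `kornia.contrib.distance_transform`:

import numpy as np
from scipy.ndimage import distance_transform_edt

def next_click(error_region: np.ndarray) -> tuple:
    # error_region: (H, W) boolean mask of mispredicted pixels.
    # Pad with background so the region border counts, then take the interior maximum.
    dist = distance_transform_edt(np.pad(error_region, 1))[1:-1, 1:-1]
    return np.unravel_index(dist.argmax(), dist.shape)  # (row, col) of the simulated click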
    def semantic_inference(self, mask_cls, mask_pred):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

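`semantic_inference` collapses per-query predictions into per-class maps: each pixel's class score is the sum over queries of (class probability x mask probability). The einsum is equivalent to a matrix product over the query dimension:

import torch

Q, C, H, W = 4, 3, 2, 2
mask_cls = torch.rand(Q, C)      # per-query class probabilities (no-object column already removed)
mask_pred = torch.rand(Q, H, W)  # per-query mask probabilities
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
assert torch.allclose(semseg, (mask_cls.t() @ mask_pred.flatten(1)).view(C, H, W))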
    def panoptic_inference(self, mask_cls, mask_pred):
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []

        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            return panoptic_seg, segments_info
        else:
            # take argmax
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                    if mask_area / original_area < self.overlap_threshold:
                        continue

                    # merge stuff regions
                    if not isthing:
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id

                    segments_info.append(
                        {
                            "id": current_segment_id,
                            "isthing": bool(isthing),
                            "category_id": int(pred_class),
                        }
                    )

            return panoptic_seg, segments_info

    def instance_inference(self, mask_cls, mask_pred, box_pred):
        # mask_pred is already processed to have the same shape as original input
        image_size = mask_pred.shape[-2:]

        # [Q, K]
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)

        labels_per_image = labels[topk_indices]
        topk_indices = (topk_indices // self.sem_seg_head.num_classes)
        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
        mask_pred = mask_pred[topk_indices]
        if box_pred is not None:
            box_pred = box_pred[topk_indices]

        # if this is panoptic segmentation, we only keep the "thing" classes
        if self.panoptic_on:
            keep = torch.zeros_like(scores_per_image).bool()
            for i, lab in enumerate(labels_per_image):
                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()

            scores_per_image = scores_per_image[keep]
            labels_per_image = labels_per_image[keep]
            mask_pred = mask_pred[keep]

            if box_pred is not None:
                box_pred = box_pred[keep]

        result = Instances(image_size)
        # mask (before sigmoid)
        result.pred_masks = (mask_pred > 0).float()
        # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
        # Uncomment the following to get boxes from masks (this is slow)

        if box_pred is not None:
            result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
        else:
            result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))

        # calculate average mask prob
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.pred_classes = labels_per_image

        return result


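`instance_inference` runs top-k over the flattened (query, class) score matrix, so each selected entry has to be mapped back to its query index and class label. A tiny sketch of that bookkeeping:

import torch

num_queries, num_classes = 3, 5
scores = torch.rand(num_queries, num_classes)
topk_scores, topk_flat = scores.flatten(0, 1).topk(4, sorted=False)

labels = torch.arange(num_classes).repeat(num_queries)  # class label of every flattened entry
picked_labels = labels[topk_flat]            # equivalently topk_flat % num_classes
picked_queries = topk_flat // num_classes    # which query each detection came from
assert (picked_labels == topk_flat % num_classes).all()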
@register_model
def get_seem_model(cfg, **kwargs):
    return GeneralizedSEEM(cfg)
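`get_seem_model` is the registry entry point for this architecture. A hedged usage sketch of the calling convention implied by `forward()`'s docstring (illustrative only; the Space itself builds and restores the model through `main.py`, `modeling/BaseModel.py` and `configs/biomedparse_inference.yaml`):

import torch

model = get_seem_model(cfg).eval().cuda()             # cfg: the parsed configuration dict
image = torch.randint(0, 255, (3, 512, 512)).cuda()   # (C, H, W) image tensor
batched_inputs = [{"image": image, "height": 512, "width": 512}]
with torch.no_grad():
    results = model(batched_inputs)                   # one result dict per input image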
modeling/architectures/seem_model_v0.py
ADDED
@@ -0,0 +1,1160 @@
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import random
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from kornia.contrib import distance_transform

from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog

from .build import register_model

from ..utils import configurable, get_class_names, get_iou
from ..vision.backbone import build_backbone, Backbone
from ..body import build_xdecoder_head
from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
from ..language import build_language_encoder
from ..language.loss import vl_similarity
from utilities.prompt_engineering import prompt_engineering
from utilities.constants import COCO_PANOPTIC_CLASSES


class GeneralizedSEEM(nn.Module):
|
33 |
+
|
34 |
+
@configurable
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
*,
|
38 |
+
backbone: Backbone,
|
39 |
+
sem_seg_head: nn.Module,
|
40 |
+
criterion: nn.Module,
|
41 |
+
losses: dict,
|
42 |
+
num_queries: int,
|
43 |
+
object_mask_threshold: float,
|
44 |
+
overlap_threshold: float,
|
45 |
+
metadata,
|
46 |
+
task_switch: dict,
|
47 |
+
phrase_prob: float,
|
48 |
+
size_divisibility: int,
|
49 |
+
sem_seg_postprocess_before_inference: bool,
|
50 |
+
pixel_mean: Tuple[float],
|
51 |
+
pixel_std: Tuple[float],
|
52 |
+
# inference
|
53 |
+
semantic_on: bool,
|
54 |
+
panoptic_on: bool,
|
55 |
+
instance_on: bool,
|
56 |
+
test_topk_per_image: int,
|
57 |
+
train_dataset_name: str,
|
58 |
+
interactive_mode: str,
|
59 |
+
interactive_iter: str,
|
60 |
+
dilation_kernel: torch.Tensor,
|
61 |
+
train_max_iter: int,
|
62 |
+
):
|
63 |
+
"""
|
64 |
+
Args:
|
65 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
66 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
67 |
+
criterion: a module that defines the loss
|
68 |
+
num_queries: int, number of queries
|
69 |
+
object_mask_threshold: float, threshold to filter query based on classification score
|
70 |
+
for panoptic segmentation inference
|
71 |
+
overlap_threshold: overlap threshold used in general inference for panoptic segmentation
|
72 |
+
metadata: dataset meta, get `thing` and `stuff` category names for panoptic
|
73 |
+
segmentation inference
|
74 |
+
size_divisibility: Some backbones require the input height and width to be divisible by a
|
75 |
+
specific integer. We can use this to override such requirement.
|
76 |
+
sem_seg_postprocess_before_inference: whether to resize the prediction back
|
77 |
+
to original input size before semantic segmentation inference or after.
|
78 |
+
For high-resolution dataset like Mapillary, resizing predictions before
|
79 |
+
inference will cause OOM error.
|
80 |
+
pixel_mean, pixel_std: list or tuple with #channels element, representing
|
81 |
+
the per-channel mean and std to be used to normalize the input image
|
82 |
+
semantic_on: bool, whether to output semantic segmentation prediction
|
83 |
+
instance_on: bool, whether to output instance segmentation prediction
|
84 |
+
panoptic_on: bool, whether to output panoptic segmentation prediction
|
85 |
+
test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
|
86 |
+
"""
|
87 |
+
super().__init__()
|
88 |
+
self.backbone = backbone
|
89 |
+
self.sem_seg_head = sem_seg_head
|
90 |
+
self.criterion = criterion
|
91 |
+
self.losses = losses
|
92 |
+
self.num_queries = num_queries
|
93 |
+
self.overlap_threshold = overlap_threshold
|
94 |
+
self.object_mask_threshold = object_mask_threshold
|
95 |
+
self.metadata = metadata
|
96 |
+
if size_divisibility < 0:
|
97 |
+
# use backbone size_divisibility if not set
|
98 |
+
size_divisibility = self.backbone.size_divisibility
|
99 |
+
self.size_divisibility = size_divisibility
|
100 |
+
self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
|
101 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
102 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
103 |
+
|
104 |
+
# additional args
|
105 |
+
self.semantic_on = semantic_on
|
106 |
+
self.instance_on = instance_on
|
107 |
+
self.panoptic_on = panoptic_on
|
108 |
+
|
109 |
+
# caption argument
|
110 |
+
self.task_switch = task_switch
|
111 |
+
self.phrase_prob = phrase_prob
|
112 |
+
self.train_max_iter = train_max_iter
|
113 |
+
|
114 |
+
self.test_topk_per_image = test_topk_per_image
|
115 |
+
self.train_class_names = get_class_names(train_dataset_name)
|
116 |
+
self.interactive_mode = interactive_mode
|
117 |
+
self.interactive_iter = interactive_iter
|
118 |
+
|
119 |
+
if not self.semantic_on:
|
120 |
+
assert self.sem_seg_postprocess_before_inference
|
121 |
+
|
122 |
+
self.register_buffer("dilation_kernel", dilation_kernel)
|
123 |
+
|
124 |
+
@classmethod
|
125 |
+
def from_config(cls, cfg):
|
126 |
+
enc_cfg = cfg['MODEL']['ENCODER']
|
127 |
+
dec_cfg = cfg['MODEL']['DECODER']
|
128 |
+
|
129 |
+
# Loss parameters:
|
130 |
+
deep_supervision = dec_cfg['DEEP_SUPERVISION']
|
131 |
+
no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
|
132 |
+
|
133 |
+
# loss weights
|
134 |
+
loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
|
135 |
+
'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
|
136 |
+
'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
|
137 |
+
'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
|
138 |
+
'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
|
139 |
+
|
140 |
+
openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
|
141 |
+
'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
|
142 |
+
|
143 |
+
task_switch = {'bbox': dec_cfg.get('DETECTION', False),
|
144 |
+
'mask': dec_cfg['MASK'].get('ENABLED', True),
|
145 |
+
'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
|
146 |
+
'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
|
147 |
+
'openimage': openimage_switch}
|
148 |
+
|
149 |
+
top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
|
150 |
+
'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
|
151 |
+
'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
|
152 |
+
'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
|
153 |
+
|
154 |
+
spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
|
155 |
+
"mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
|
156 |
+
"dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
|
157 |
+
|
158 |
+
extra = {'task_switch': task_switch}
|
159 |
+
backbone = build_backbone(cfg)
|
160 |
+
lang_encoder = build_language_encoder(cfg)
|
161 |
+
sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
|
162 |
+
|
163 |
+
# building criterion
|
164 |
+
matcher = HungarianMatcher(
|
165 |
+
cost_class=loss_weights['mask']['ce'],
|
166 |
+
cost_mask=loss_weights['mask']['bce'],
|
167 |
+
cost_dice=loss_weights['mask']['dice'],
|
168 |
+
num_points=dec_cfg['TRAIN_NUM_POINTS'],
|
169 |
+
spatial_cost=spatial_cost,
|
170 |
+
)
|
171 |
+
|
172 |
+
# init weight dict and criterion loss functions.
|
173 |
+
losses = {'seg': [], 'openimage': []}
|
174 |
+
if task_switch['mask']:
|
175 |
+
losses['seg'] += ["labels", "masks"]
|
176 |
+
if task_switch['spatial']:
|
177 |
+
losses['seg'] += ["spatials"]
|
178 |
+
if task_switch['grounding']:
|
179 |
+
losses['seg'] += ["groundings"]
|
180 |
+
if task_switch['openimage']:
|
181 |
+
losses['openimage'] += ["labels_openimage", "masks"]
|
182 |
+
if task_switch['openimage']['grounding']:
|
183 |
+
losses['openimage'] += ["groundings"]
|
184 |
+
|
185 |
+
weight_dict = {}
|
186 |
+
for key, turn_on in task_switch.items():
|
187 |
+
if turn_on:
|
188 |
+
if isinstance(loss_weights[key], dict):
|
189 |
+
# HACK it should support bbox in the future
|
190 |
+
for key_, weight in loss_weights[key].items():
|
191 |
+
weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
|
192 |
+
else:
|
193 |
+
weight_dict["loss_{}_0".format(key)] = loss_weights[key]
|
194 |
+
|
195 |
+
# generate full weight dict and remove not computed layers.
|
196 |
+
if deep_supervision:
|
197 |
+
dec_layers = dec_cfg['DEC_LAYERS']
|
198 |
+
aux_weight_dict = {}
|
199 |
+
for i in range(dec_layers - 1):
|
200 |
+
for k, v in weight_dict.items():
|
201 |
+
if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
|
202 |
+
continue
|
203 |
+
aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
|
204 |
+
weight_dict.update(aux_weight_dict)
|
205 |
+
|
206 |
+
grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
|
207 |
+
# generate critenrion for loss function.
|
208 |
+
criterion = SetCriterion(
|
209 |
+
sem_seg_head.num_classes,
|
210 |
+
matcher=matcher,
|
211 |
+
weight_dict=weight_dict,
|
212 |
+
top_x_layers=top_x_layers,
|
213 |
+
eos_coef=no_object_weight,
|
214 |
+
losses=[],
|
215 |
+
num_points=dec_cfg['TRAIN_NUM_POINTS'],
|
216 |
+
oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
|
217 |
+
importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
|
218 |
+
grounding_weight=grd_weight,
|
219 |
+
)
|
220 |
+
|
221 |
+
# extra logistic
|
222 |
+
train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
|
223 |
+
train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
|
224 |
+
phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
|
225 |
+
interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
|
226 |
+
interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
|
227 |
+
|
228 |
+
dilation = 3
|
229 |
+
dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
|
230 |
+
|
231 |
+
return {
|
232 |
+
"backbone": backbone,
|
233 |
+
"sem_seg_head": sem_seg_head,
|
234 |
+
"criterion": criterion,
|
235 |
+
"losses": losses,
|
236 |
+
"num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
|
237 |
+
"object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
|
238 |
+
"overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
|
239 |
+
"metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
|
240 |
+
"size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
|
241 |
+
"sem_seg_postprocess_before_inference": (
|
242 |
+
dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
|
243 |
+
or dec_cfg['TEST']['PANOPTIC_ON']
|
244 |
+
or dec_cfg['TEST']['INSTANCE_ON']
|
245 |
+
),
|
246 |
+
"pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
|
247 |
+
"pixel_std": cfg['INPUT']['PIXEL_STD'],
|
248 |
+
"task_switch": task_switch,
|
249 |
+
"phrase_prob": phrase_prob,
|
250 |
+
# inference
|
251 |
+
"semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
|
252 |
+
"instance_on": dec_cfg['TEST']['INSTANCE_ON'],
|
253 |
+
"panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
|
254 |
+
"test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
|
255 |
+
"train_dataset_name": train_dataset_name,
|
256 |
+
"interactive_mode": interactive_mode,
|
257 |
+
"interactive_iter": interactive_iter,
|
258 |
+
"dilation_kernel": dilation_kernel,
|
259 |
+
"train_max_iter": train_max_iter,
|
260 |
+
}
|
261 |
+
|
262 |
+
@property
|
263 |
+
def device(self):
|
264 |
+
return self.pixel_mean.device
|
265 |
+
|
266 |
+
def forward(self, batched_inputs, mode='default'):
|
267 |
+
"""
|
268 |
+
Args:
|
269 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
270 |
+
Each item in the list contains the inputs for one image.
|
271 |
+
For now, each item in the list is a dict that contains:
|
272 |
+
* "image": Tensor, image in (C, H, W) format.
|
273 |
+
* "instances": per-region ground truth
|
274 |
+
* Other information that's included in the original dicts, such as:
|
275 |
+
"height", "width" (int): the output resolution of the model (may be different
|
276 |
+
from input resolution), used in inference.
|
277 |
+
Returns:
|
278 |
+
list[dict]:
|
279 |
+
each dict has the results for one image. The dict contains the following keys:
|
280 |
+
|
281 |
+
* "sem_seg":
|
282 |
+
A Tensor that represents the
|
283 |
+
per-pixel segmentation prediced by the head.
|
284 |
+
The prediction has shape KxHxW that represents the logits of
|
285 |
+
each class for each pixel.
|
286 |
+
* "panoptic_seg":
|
287 |
+
A tuple that represent panoptic output
|
288 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
|
289 |
+
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
|
290 |
+
Each dict contains keys "id", "category_id", "isthing".
|
291 |
+
"""
|
292 |
+
if self.training:
|
293 |
+
losses = {}
|
294 |
+
if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
|
295 |
+
losses_seg = self.forward_seg(batched_inputs)
|
296 |
+
losses.update(losses_seg)
|
297 |
+
if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
|
298 |
+
losses_openimage = self.forward_openimage(batched_inputs['openimage'])
|
299 |
+
losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
|
300 |
+
losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
|
301 |
+
losses.update(losses_openimage)
|
302 |
+
for k in list(losses.keys()):
|
303 |
+
if k in self.criterion.weight_dict:
|
304 |
+
losses[k] *= self.criterion.weight_dict[k]
|
305 |
+
else: # remove this loss if not specified in `weight_dict`
|
306 |
+
losses.pop(k)
|
307 |
+
return losses
|
308 |
+
else:
|
309 |
+
if mode == 'interactive':
|
310 |
+
return self.evaluate_interactive(batched_inputs)
|
311 |
+
elif mode == 'interactive_grounding':
|
312 |
+
return self.evaluate_interactive_grounding(batched_inputs)
|
313 |
+
elif mode == 'grounding_spatial':
|
314 |
+
return self.evaluate_grounding_sptial(batched_inputs, mode)
|
315 |
+
elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
|
316 |
+
return self.evaluate_grounding(batched_inputs, mode)
|
317 |
+
else:
|
318 |
+
return self.evaluate(batched_inputs)
|
319 |
+
|
320 |
+
|
321 |
+
def forward_seg(self, batched_inputs):
|
322 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
323 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
324 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
325 |
+
|
326 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
|
327 |
+
|
328 |
+
extra = {}
|
329 |
+
# mask classification target
|
330 |
+
if "instances" in batched_inputs[0]:
|
331 |
+
# input bounding box is checked to be correct.
|
332 |
+
targets = self.prepare_targets(batched_inputs, images)
|
333 |
+
|
334 |
+
if self.task_switch['grounding']:
|
335 |
+
grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
|
336 |
+
grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
|
337 |
+
non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
|
338 |
+
grounding_tokens[non_zero_query_mask] = 0
|
339 |
+
|
340 |
+
extra['grounding_tokens'] = grounding_tokens
|
341 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
342 |
+
|
343 |
+
if self.task_switch['spatial']:
|
344 |
+
pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
|
345 |
+
neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
|
346 |
+
fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
|
347 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
|
348 |
+
|
349 |
+
features = self.backbone(images.tensor)
|
350 |
+
mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
351 |
+
|
352 |
+
# forward spatial only without gradient
|
353 |
+
if self.task_switch['spatial']:
|
354 |
+
with torch.no_grad():
|
355 |
+
# generate random integeter between [0,3]
|
356 |
+
rand_iter_num = random.randint(0, self.train_max_iter)
|
357 |
+
for i in range(rand_iter_num):
|
358 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
|
359 |
+
extra.update(outputs)
|
360 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
|
361 |
+
|
362 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
|
363 |
+
|
364 |
+
extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
|
365 |
+
'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
|
366 |
+
'false_positive_mask': extra['false_positive_mask']}
|
367 |
+
# bipartite matching-based loss
|
368 |
+
self.criterion.losses = self.losses['seg'] # seg criterion losses
|
369 |
+
losses = self.criterion(outputs, targets, extra)
|
370 |
+
|
371 |
+
del outputs
|
372 |
+
return losses
|
373 |
+
|
374 |
+
def evaluate(self, batched_inputs):
|
375 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
376 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
377 |
+
|
378 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
379 |
+
img_bs = images.tensor.shape[0]
|
380 |
+
|
381 |
+
targets = targets_grounding = queries_grounding = None
|
382 |
+
features = self.backbone(images.tensor)
|
383 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
384 |
+
|
385 |
+
mask_cls_results = outputs["pred_logits"]
|
386 |
+
mask_pred_results = outputs["pred_masks"]
|
387 |
+
box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
|
388 |
+
|
389 |
+
# upsample masks
|
390 |
+
mask_pred_results = F.interpolate(
|
391 |
+
mask_pred_results,
|
392 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
393 |
+
mode="bilinear",
|
394 |
+
align_corners=False,
|
395 |
+
)
|
396 |
+
|
397 |
+
input_size = mask_pred_results.shape[-2:]
|
398 |
+
del outputs
|
399 |
+
|
400 |
+
processed_results = []
|
401 |
+
for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
|
402 |
+
mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
|
403 |
+
):
|
404 |
+
height = input_per_image.get("height", image_size[0])
|
405 |
+
width = input_per_image.get("width", image_size[1])
|
406 |
+
processed_results.append({})
|
407 |
+
|
408 |
+
if self.sem_seg_postprocess_before_inference:
|
409 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
410 |
+
mask_pred_result, image_size, height, width
|
411 |
+
)
|
412 |
+
mask_cls_result = mask_cls_result.to(mask_pred_result)
|
413 |
+
|
414 |
+
# semantic segmentation inference
|
415 |
+
if self.semantic_on:
|
416 |
+
r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
|
417 |
+
if not self.sem_seg_postprocess_before_inference:
|
418 |
+
r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
|
419 |
+
processed_results[-1]["sem_seg"] = r
|
420 |
+
|
421 |
+
# panoptic segmentation inference
|
422 |
+
if self.panoptic_on:
|
423 |
+
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
|
424 |
+
processed_results[-1]["panoptic_seg"] = panoptic_r
|
425 |
+
|
426 |
+
# instance segmentation inference
|
427 |
+
if self.instance_on:
|
428 |
+
if self.task_switch['bbox']:
|
429 |
+
box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
|
430 |
+
instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
|
431 |
+
processed_results[-1]["instances"] = instance_r
|
432 |
+
|
433 |
+
return processed_results
|
434 |
+
|
435 |
+
def evaluate_interactive(self, batched_inputs):
|
436 |
+
assert self.task_switch['spatial']
|
437 |
+
assert 'spatial_query' in batched_inputs[0]
|
438 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
439 |
+
|
440 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
441 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
442 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
443 |
+
img_bs = images.tensor.shape[0]
|
444 |
+
|
445 |
+
targets = targets_grounding = queries_grounding = None
|
446 |
+
extra = {}
|
447 |
+
|
448 |
+
features = self.backbone(images.tensor)
|
449 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
450 |
+
|
451 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
452 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
453 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
454 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
455 |
+
|
456 |
+
all_batch_shape_iou = []
|
457 |
+
pred_smask_pointer = None
|
458 |
+
prev_smask_pointer = None
|
459 |
+
pred_smask_all = None
|
460 |
+
|
461 |
+
# visualization code
|
462 |
+
# v_pred_mask = []
|
463 |
+
# v_pos_mask = []
|
464 |
+
# v_neg_mask = []
|
465 |
+
# v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
|
466 |
+
query_index = self.sem_seg_head.predictor.query_index
|
467 |
+
if self.interactive_mode in ['best', 'best_random']:
|
468 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
469 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
470 |
+
|
471 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
472 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
473 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
474 |
+
elif self.interactive_mode == 'random':
|
475 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
|
476 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
|
477 |
+
|
478 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
|
479 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
|
480 |
+
extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
|
481 |
+
else:
|
482 |
+
assert False, "invalid interactive mode"
|
483 |
+
|
484 |
+
for i in range(self.interactive_iter):
|
485 |
+
# v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
|
486 |
+
# v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
|
487 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
|
488 |
+
extra.update(outputs)
|
489 |
+
pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
|
490 |
+
# v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
|
491 |
+
|
492 |
+
s = image_sizes[0]
|
493 |
+
b = batched_inputs[0]
|
494 |
+
pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
|
495 |
+
gt_smask = b['gt_masks_orisize']
|
496 |
+
ious = get_iou(gt_smask, pred_smask_all)
|
497 |
+
all_batch_shape_iou += [ious]
|
498 |
+
if (ious > 0.9).sum() == len(ious):
|
499 |
+
all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
|
500 |
+
break
|
501 |
+
if self.interactive_mode in ['best', 'best_random']:
|
502 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
|
503 |
+
elif self.interactive_mode == 'random':
|
504 |
+
extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
|
505 |
+
else:
|
506 |
+
assert False, "invalid interactive mode"
|
507 |
+
all_batch_shape_iou = torch.stack(all_batch_shape_iou)
|
508 |
+
processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
|
509 |
+
|
510 |
+
return processed_results
|
511 |
+
|
512 |
+
def evaluate_interactive_single(self, batched_inputs, extra={}):
|
513 |
+
assert self.task_switch['spatial']
|
514 |
+
assert 'spatial_query' in batched_inputs[0]
|
515 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
516 |
+
|
517 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
518 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
519 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
520 |
+
img_bs = images.tensor.shape[0]
|
521 |
+
|
522 |
+
targets = targets_grounding = queries_grounding = None
|
523 |
+
|
524 |
+
features = self.backbone(images.tensor)
|
525 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
526 |
+
|
527 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
528 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
529 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
530 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
531 |
+
|
532 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
|
533 |
+
pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
|
534 |
+
|
535 |
+
s = image_sizes[0]
|
536 |
+
b = batched_inputs[0]
|
537 |
+
pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
|
538 |
+
pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
|
539 |
+
ious = []
|
540 |
+
if 'gt_masks_orisize' in b:
|
541 |
+
gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
|
542 |
+
ious = get_iou(gt_smask, pred_smask_ori)
|
543 |
+
processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
|
544 |
+
return processed_results
|
545 |
+
|
546 |
+
def evaluate_interactive_grounding(self, batched_inputs):
|
547 |
+
assert self.task_switch['spatial']
|
548 |
+
assert 'spatial_query' in batched_inputs[0]
|
549 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
550 |
+
|
551 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
552 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
553 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
554 |
+
img_bs = images.tensor.shape[0]
|
555 |
+
|
556 |
+
targets = targets_grounding = queries_grounding = None
|
557 |
+
extra = {}
|
558 |
+
|
559 |
+
features = self.backbone(images.tensor)
|
560 |
+
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm,1,1,1)

        all_batch_shape_iou = []
        pred_smask_pointer = None
        prev_smask_pointer = None
        pred_smask_all = None

        # visualization code
        # v_pred_mask = []
        # v_pos_mask = []
        # v_neg_mask = []
        # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
        query_index = self.sem_seg_head.predictor.query_index
        if self.interactive_mode in ['best', 'best_random']:
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
        elif self.interactive_mode == 'random':
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
            extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
        else:
            assert False, "invalid interactive mode"

        grd_texts = batched_inputs[0]['classes']
        gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
        token_emb = gtext['token_emb']
        tokens = gtext['tokens']
        query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
        non_zero_query_mask = (query_emb.sum(dim=-1) < 0)

        extra['grounding_tokens'] = query_emb
        extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

        for i in range(self.interactive_iter):
            # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
            extra.update(outputs)
            pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
            # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]

            s = image_sizes[0]
            b = batched_inputs[0]
            pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
            gt_smask = b['gt_masks_orisize']
            ious = get_iou(gt_smask, pred_smask_all)
            all_batch_shape_iou += [ious]
            if (ious > 0.9).sum() == len(ious):
                all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
                break
            if self.interactive_mode in ['best', 'best_random']:
                extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
            elif self.interactive_mode == 'random':
                extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
            else:
                assert False, "invalid interactive mode"
        all_batch_shape_iou = torch.stack(all_batch_shape_iou)
        processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]

        # visualization
        # VL.step()
        # import cv2
        # v_masks = []
        # v_pos_masks = []
        # v_neg_masks = []
        # txt = []

        # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
        # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
        # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
        # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
        # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
        #     # dilate x,y
        #     x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
        #     y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
        #     acc_pos_mask += x
        #     acc_neg_mask += y

        #     v_masks += [z]
        #     v_pos_masks += [acc_pos_mask.clip(0,1)]
        #     v_neg_masks += [acc_neg_mask.clip(0,1)]
        #     txt += ["pred_{}".format(str(iou[0].item())[0:5])]

        # VL.add_image(img[:,:,::-1])
        # VL.insert(mask_img, "gt_mask")
        # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
        return processed_results
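The interactive loop above stops refining once every simulated click reaches IoU > 0.9 against the ground truth. A minimal sketch of the per-mask IoU it relies on (hypothetical re-implementation for illustration; the actual `get_iou` is imported from `..utils`):

import torch

def iou_per_mask(gt: torch.Tensor, pred: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # gt, pred: (N, H, W) boolean masks -> (N,) IoU scores
    inter = (gt & pred).flatten(1).sum(-1).float()
    union = (gt | pred).flatten(1).sum(-1).float()
    return inter / (union + eps)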
    def evaluate_referring_image(self, batched_inputs, extra={}):
        assert self.task_switch['spatial']
        assert len(batched_inputs) == 1, "only support batch size equal to 1"
        assert self.interactive_mode == 'best'

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        if 'spatial_query' in batched_inputs[0]:
            image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
            nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
            multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
            mask_features = mask_features.repeat(nm,1,1,1)

            query_index = self.sem_seg_head.predictor.query_index
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
        return outputs, images.tensor.shape

    def evaluate_grounding(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_masks = []
        #     for anno_text in grd_texts:
        #         gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
        #         token_emb = gtext['token_emb']
        #         tokens = gtext['tokens']

        #         grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
        #         extra['grounding_tokens'] = grd_emb[:,None]

        #         assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
        #         features = self.backbone(images.tensor)
        #         outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #         pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
        #         v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
        #         t_emb = grd_emb[-1:]

        #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #         temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #         matched_id = out_prob.max(0)[1]
        #         grd_masks += [pred_gmasks[matched_id,:,:]]
        #     mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_texts = [x[0] for x in grd_texts]

            gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
            token_emb = gtext['token_emb']
            tokens = gtext['tokens']
            query_emb = token_emb[tokens['attention_mask'].bool()]
            non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)

            extra['grounding_tokens'] = query_emb[:,None]
            extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            features = self.backbone(images.tensor)
            outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

            pred_gmasks = outputs['pred_gmasks'][idx]
            v_emb = outputs['pred_gtexts'][idx]
            t_emb = gtext['class_emb']

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
            out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

            matched_id = out_prob.max(0)[1]
            mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

            # compute bbox
            # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
            # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
            # processed_results[-1]['grounding_box'] = bbox

        return processed_results

    def evaluate_grounding_sptial(self, batched_inputs, mode):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"

        extra = {}
        dilation = 3
        pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
        pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
        pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)

        neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
        neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)

        mask_pred_results = []
        for idx, batch_per_image in enumerate(batched_inputs):
            grd_texts = batch_per_image['groundings']['texts']
            grd_masks = []
            for idx2, anno_text in enumerate(grd_texts):
                extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})

                gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
                token_emb = gtext['token_emb']
                tokens = gtext['tokens']

                grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
                non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
                extra['grounding_tokens'] = grd_emb[:,None]
                extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

                assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
                features = self.backbone(images.tensor)
                outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

                pred_gmasks = outputs['pred_gmasks'][idx]
                v_emb = outputs['pred_gtexts'][idx]
                t_emb = gtext['class_emb']

                t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
                v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

                temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
                out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

                matched_id = out_prob.max(0)[1]
                grd_masks += [pred_gmasks[matched_id,:,:]]
                # grd_masks += [outputs['prev_mask'][0]]

            mask_pred_results += [torch.cat(grd_masks)]

        # comment for multi object inference.
        # mask_pred_results = []
        # for idx, batch_per_image in enumerate(batched_inputs):
        #     grd_texts = batch_per_image['groundings']['texts']
        #     grd_texts = [x[0] for x in grd_texts]

        #     gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
        #     token_emb = gtext['token_emb']
        #     tokens = gtext['tokens']
        #     query_emb = token_emb[tokens['attention_mask'].bool()]
        #     non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)

        #     extra['grounding_tokens'] = query_emb[:,None]
        #     extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

        #     features = self.backbone(images.tensor)
        #     outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')

        #     pred_gmasks = outputs['pred_gmasks'][idx]
        #     v_emb = outputs['pred_gtexts'][idx]
        #     t_emb = gtext['class_emb']

        #     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        #     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        #     temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
        #     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

        #     matched_id = out_prob.max(0)[1]
        #     mask_pred_results += [pred_gmasks[matched_id,:,:]]

        for i in range(len(mask_pred_results)):
            # upsample masks
            mask_pred_results[i] = F.interpolate(
                mask_pred_results[i][None,],
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )[0]

        processed_results = []
        for mask_pred_result, input_per_image, image_size in zip(
            mask_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, image_size, height, width
            )
            processed_results[-1]['grounding_mask'] = mask_pred_result

        return processed_results
    def prepare_targets(self, batched_inputs, images):
        h_pad, w_pad = images.tensor.shape[-2:]
        new_targets = []
        for idx, batch_per_image in enumerate(batched_inputs):
            targets_per_image = batch_per_image['instances'].to(self.device)
            # pad gt
            gt_masks = targets_per_image.gt_masks.tensor
            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks

            gt_boxes = targets_per_image.gt_boxes.tensor
            ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
            gt_boxes = gt_boxes / ratio
            xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
            gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)

            target_dict = {
                "labels": targets_per_image.gt_classes,
                "is_things": targets_per_image.is_things,
                "masks": padded_masks,
                "boxes": gt_boxes,
            }

            if self.task_switch['spatial']:
                # prepare targets for spatial query
                target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']

            if self.task_switch['grounding']:
                grd_masks = batch_per_image['groundings']['masks']
                grd_texts = batch_per_image['groundings']['texts']
                grd_hash = batch_per_image['groundings']['hash']
                grd_task = batch_per_image['groundings']['mode']

                if len(grd_masks) == 0:
                    padded_masks = None
                else:
                    padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
                    padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks

                gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
                token_emb = gtext['token_emb']
                tokens = gtext['tokens']

                unique_hash_id = np.unique(grd_hash, return_index=True)[1]
                selected_mask = np.zeros(len(grd_hash)).astype(np.bool)
                selected_mask[unique_hash_id] = True

                selected_token_emb = token_emb[selected_mask]
                selected_attn_mask = tokens['attention_mask'][selected_mask]
                query_emb = selected_token_emb[selected_attn_mask.bool()]

                class_idx = tokens['attention_mask'].sum(dim=-1) - 1
                class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
                class_emb = token_emb[class_idx]

                target_dict['grounding_masks'] = padded_masks
                target_dict['grounding_query_embs'] = query_emb
                target_dict['grounding_class_embs'] = class_emb
                target_dict['grounding_hash'] = grd_hash
                target_dict['grounding_task'] = grd_task

            new_targets.append(target_dict)
        return new_targets

    def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
        gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
        if self.training:
            gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
        else:
            gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)

        pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
        prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])

        fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
        fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0

        # compute iou between gt and pred
        iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
        fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
        fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))

        is_postive = fn_sum > fp_sum
        # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
        select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])

        # conv implementation
        n,_,h,w = select_mask.shape
        mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
        if mode == 'best':
            max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
        elif mode == 'best_random':
            max_xy_idx = torch.stack([torch.arange(n), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
        next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
        next_mask = next_mask.view(n,-1)
        next_mask[max_xy_idx] = True
        next_mask = next_mask.reshape((n,1,h,w)).float()
        dilation = 3
        next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0

        # determine whether next mask is zero
        keep = (iou < 0.925)
        next_mask = next_mask & keep.view(-1,1,1,1)

        pos_mask = []
        neg_mask = []
        for idx, ip in enumerate(is_postive):
            if ip:
                pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
                neg_mask += [outputs['spatial_query_neg_mask'][idx]]
            else:
                pos_mask += [outputs['spatial_query_pos_mask'][idx]]
                neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]

        if 'false_positive_mask' in outputs:
            fp = outputs['false_positive_mask'] | fp
        return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
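`prepare_next_spaital_mask` above simulates the next user click by picking a point inside the current error region (false negatives if they dominate, otherwise false positives) and dilating it into a small disc. A simplified sketch of that sampling step (this mirrors the 'best_random' branch; the 'best' branch instead takes the arg-max of a distance transform over the error region):

import torch
import torch.nn.functional as F

def sample_click(error_mask: torch.Tensor, dilation: int = 3) -> torch.Tensor:
    # error_mask: (H, W) bool; returns an (H, W) bool mask holding one dilated click
    click = torch.zeros_like(error_mask)
    ys, xs = error_mask.nonzero(as_tuple=True)
    if len(ys) == 0:
        return click
    j = torch.randint(0, len(ys), (1,)).item()
    click[ys[j], xs[j]] = True
    kernel = torch.ones(1, 1, dilation, dilation)
    return (F.conv2d(click[None, None].float(), kernel, padding=dilation // 2) > 0)[0, 0]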
    def semantic_inference(self, mask_cls, mask_pred):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

    def panoptic_inference(self, mask_cls, mask_pred):
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []

        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            return panoptic_seg, segments_info
        else:
            # take argmax
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                    if mask_area / original_area < self.overlap_threshold:
                        continue

                    # merge stuff regions
                    if not isthing:
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id

                    segments_info.append(
                        {
                            "id": current_segment_id,
                            "isthing": bool(isthing),
                            "category_id": int(pred_class),
                        }
                    )

            return panoptic_seg, segments_info

    def instance_inference(self, mask_cls, mask_pred, box_pred):
        # mask_pred is already processed to have the same shape as original input
        image_size = mask_pred.shape[-2:]

        # [Q, K]
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)

        labels_per_image = labels[topk_indices]
        topk_indices = (topk_indices // self.sem_seg_head.num_classes)
        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
        mask_pred = mask_pred[topk_indices]
        if box_pred is not None:
            box_pred = box_pred[topk_indices]

        # if this is panoptic segmentation, we only keep the "thing" classes
        if self.panoptic_on:
            keep = torch.zeros_like(scores_per_image).bool()
            for i, lab in enumerate(labels_per_image):
                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()

            scores_per_image = scores_per_image[keep]
            labels_per_image = labels_per_image[keep]
            mask_pred = mask_pred[keep]

            if box_pred is not None:
                box_pred = box_pred[keep]

        result = Instances(image_size)
        # mask (before sigmoid)
        result.pred_masks = (mask_pred > 0).float()
        # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
        # Uncomment the following to get boxes from masks (this is slow)

        if box_pred is not None:
            result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
        else:
            result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))

        # calculate average mask prob
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        result.pred_classes = labels_per_image

        return result

    def prepare_targets4query(self, targets, images, topk=5):
        h_pad, w_pad = images.tensor.shape[-2:]
        new_targets = []
        new_queries = []
        for targets_per_image in targets:
            # we randomly sample maximally topk concepts
            unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
            selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
            new_targets_per_image = []
            new_queries_per_image = []
            for clss in selected_target_classes:
                indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
                # pad gt
                gt_masks = targets_per_image.gt_masks[indices]
                padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
                padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks

                # convert class into concept name and then token seq
                self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([COCO_PANOPTIC_CLASSES[clss]], name='grounding')
                query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')

                new_targets.append(
                    {
                        "labels": targets_per_image.gt_classes[indices],
                        "masks": padded_masks,
                    }
                )
                new_queries_per_image.append(query)
            new_queries.append(new_queries_per_image)

        return new_targets, new_queries


@register_model
def get_seem_model(cfg, **kwargs):
    return GeneralizedSEEM(cfg)
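`get_seem_model` is registered through the `register_model` decorator imported from `.build`, so a config can select the architecture by name. A self-contained sketch of that registry pattern (illustrative only; the real registry and its lookup helper live in `.build`, and the names below are hypothetical):

_MODEL_REGISTRY = {}

def register(fn):
    # store the factory under its function name
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn

@register
def get_toy_model(cfg, **kwargs):
    return {"built_with": cfg}

# look the factory up by name and build the model from a config dict
model = _MODEL_REGISTRY["get_toy_model"]({"NAME": "toy"})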
modeling/architectures/seem_model_v1.py
ADDED
@@ -0,0 +1,1179 @@
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import random
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from kornia.contrib import distance_transform

from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog

from .build import register_model

from ..utils import configurable, get_class_names, get_iou, Spatial_ImageList
from ..vision.backbone import build_backbone, Backbone
from ..body import build_xdecoder_head
from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
from ..language import build_language_encoder
from ..language.loss import vl_similarity
from utilities.prompt_engineering import prompt_engineering
from utilities.constants import COCO_PANOPTIC_CLASSES, BIOMED_CLASSES


class GeneralizedSEEM(nn.Module):

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        losses: dict,
        num_queries: int,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        task_switch: dict,
        phrase_prob: float,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        # inference
        semantic_on: bool,
        panoptic_on: bool,
        instance_on: bool,
        test_topk_per_image: int,
        train_dataset_name: str,
        interactive_mode: str,
        interactive_iter: str,
        dilation_kernel: torch.Tensor,
        train_max_iter: int,
        binary_classes: bool,
        standard_text_for_eval: bool,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
            semantic_on: bool, whether to output semantic segmentation prediction
            instance_on: bool, whether to output instance segmentation prediction
            panoptic_on: bool, whether to output panoptic segmentation prediction
            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.losses = losses
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # additional args
        self.semantic_on = semantic_on
        self.instance_on = instance_on
        self.panoptic_on = panoptic_on

        # caption argument
        self.task_switch = task_switch
        self.phrase_prob = phrase_prob
        self.train_max_iter = train_max_iter

        self.test_topk_per_image = test_topk_per_image
        self.train_class_names = get_class_names(train_dataset_name)
        if binary_classes:
            self.train_class_names = ['target', 'background']
        self.interactive_mode = interactive_mode
        self.interactive_iter = interactive_iter

        if not self.semantic_on:
            assert self.sem_seg_postprocess_before_inference

        self.register_buffer("dilation_kernel", dilation_kernel)

        self.standard_text_for_eval = standard_text_for_eval
    @classmethod
    def from_config(cls, cfg):
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        # Loss parameters:
        deep_supervision = dec_cfg['DEEP_SUPERVISION']
        no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']

        # loss weights
        loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
                        'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
                        'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
                        'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
                        'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}

        openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
                            'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}

        task_switch = {'bbox': dec_cfg.get('DETECTION', False),
                       'mask': dec_cfg['MASK'].get('ENABLED', True),
                       'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
                       'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
                       'openimage': openimage_switch}

        top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
                        'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
                        'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
                        'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}

        spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
                        "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
                        "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}

        extra = {'task_switch': task_switch}
        backbone = build_backbone(cfg)
        lang_encoder = build_language_encoder(cfg)
        sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)

        # building criterion
        matcher = HungarianMatcher(
            cost_class=loss_weights['mask']['ce'],
            cost_mask=loss_weights['mask']['bce'],
            cost_dice=loss_weights['mask']['dice'],
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
            spatial_cost=spatial_cost,
        )

        # init weight dict and criterion loss functions.
        losses = {'seg': [], 'openimage': []}
        if task_switch['mask']:
            losses['seg'] += ["labels", "masks"]
        if task_switch['spatial']:
            losses['seg'] += ["spatials"]
        if task_switch['grounding']:
            losses['seg'] += ["groundings"]
        if task_switch['openimage']:
            losses['openimage'] += ["labels_openimage", "masks"]
        if task_switch['openimage']['grounding']:
            losses['openimage'] += ["groundings"]

        weight_dict = {}
        for key, turn_on in task_switch.items():
            if turn_on:
                if isinstance(loss_weights[key], dict):
                    # HACK it should support bbox in the future
                    for key_, weight in loss_weights[key].items():
                        weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
                else:
                    weight_dict["loss_{}_0".format(key)] = loss_weights[key]

        # generate full weight dict and remove not computed layers.
        if deep_supervision:
            dec_layers = dec_cfg['DEC_LAYERS']
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                for k, v in weight_dict.items():
                    if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
                        continue
                    aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
            weight_dict.update(aux_weight_dict)

        grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
        # generate critenrion for loss function.
        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            top_x_layers=top_x_layers,
            eos_coef=no_object_weight,
            losses=[],
            num_points=dec_cfg['TRAIN_NUM_POINTS'],
            oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
            importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
            grounding_weight=grd_weight,
        )

        # extra logistic
        train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
        train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
        phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
        interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
        interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']

        dilation = 3
        dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "losses": losses,
            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
            "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
            "sem_seg_postprocess_before_inference": (
                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
                or dec_cfg['TEST']['PANOPTIC_ON']
                or dec_cfg['TEST']['INSTANCE_ON']
            ),
            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
            "pixel_std": cfg['INPUT']['PIXEL_STD'],
            "task_switch": task_switch,
            "phrase_prob": phrase_prob,
            # inference
            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
            "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
            "train_dataset_name": train_dataset_name,
            "interactive_mode": interactive_mode,
            "interactive_iter": interactive_iter,
            "dilation_kernel": dilation_kernel,
            "train_max_iter": train_max_iter,
            "binary_classes": enc_cfg['BINARY_CLASSES'],
            "standard_text_for_eval": cfg['STANDARD_TEXT_FOR_EVAL'],
        }
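`from_config` above pulls everything it needs from a nested config dict. A minimal sketch of the fields it reads (keys taken from the code above; the values are placeholders, not the shipped BiomedParse config):

cfg_sketch = {
    'MODEL': {
        'ENCODER': {'BINARY_CLASSES': False},
        'DECODER': {
            'DEEP_SUPERVISION': True, 'NO_OBJECT_WEIGHT': 0.1,
            'CLASS_WEIGHT': 2.0, 'DICE_WEIGHT': 5.0, 'MASK_WEIGHT': 5.0,
            # ... plus BBOX/GIOU, SCLASS/SDICE/SMASK, GCLASS/GDICE/GMASK, OCLASS/ODICE/OMASK weights,
            # MASK/SPATIAL/GROUNDING/OPENIMAGE switches, COST_SPATIAL, TRAIN_NUM_POINTS, DEC_LAYERS, TEST, etc.
        },
    },
    'DATASETS': {'TRAIN': ['some_train_split']},
    'INPUT': {'PIXEL_MEAN': [123.675, 116.28, 103.53], 'PIXEL_STD': [58.395, 57.12, 57.375]},
    'TEST': {'DETECTIONS_PER_IMAGE': 100},
    'STROKE_SAMPLER': {'EVAL': {'MODE': 'best', 'MAX_ITER': 20}},
    'STANDARD_TEXT_FOR_EVAL': False,
}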
    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs, mode='default'):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation prediced by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represent panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        if self.training:
            losses = {}
            if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
                losses_seg = self.forward_seg(batched_inputs)
                losses.update(losses_seg)
            if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
                losses_openimage = self.forward_openimage(batched_inputs['openimage'])
                losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
                losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
                losses.update(losses_openimage)
            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else: # remove this loss if not specified in `weight_dict`
                    losses.pop(k)
            return losses
        else:
            if mode == 'interactive':
                return self.evaluate_interactive(batched_inputs)
            elif mode == 'interactive_grounding':
                return self.evaluate_interactive_grounding(batched_inputs)
            elif mode == 'grounding_spatial':
                return self.evaluate_grounding_sptial(batched_inputs, mode)
            elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
                return self.evaluate_grounding(batched_inputs, mode)
            else:
                return self.evaluate(batched_inputs)
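At inference time `forward` above dispatches on the `mode` string. A hedged usage sketch (the `model` and `batched_inputs` names are illustrative; `batched_inputs` follows the DatasetMapper dict format documented in the docstring):

# model: a GeneralizedSEEM instance in eval mode
model.eval()
with torch.no_grad():
    # text-prompted grounding evaluation
    results = model(batched_inputs, mode='grounding_phrasecut')
    grounding_mask = results[0]['grounding_mask']
    # simulated-click interactive evaluation
    iou_per_round = model(batched_inputs, mode='interactive')[0]['mask_iou']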
    def forward_seg(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)

        extra = {}
        # mask classification target
        if "instances" in batched_inputs[0]:
            # input bounding box is checked to be correct.
            targets = self.prepare_targets(batched_inputs, images)

            if self.task_switch['grounding']:
                grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
                grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
                non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
                grounding_tokens[non_zero_query_mask] = 0

                extra['grounding_tokens'] = grounding_tokens
                extra['grounding_nonzero_mask'] = non_zero_query_mask.t()

            if self.task_switch['spatial']:
                pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
                neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
                fp_masks = nn.utils.rnn.pad_sequence([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs], padding_value=False, batch_first=True)
                extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})

        features = self.backbone(images.tensor)
        mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        # forward spatial only without gradient
        if self.task_switch['spatial']:
            with torch.no_grad():
                # generate random integeter between [0,3]
                rand_iter_num = random.randint(0, self.train_max_iter)
                for i in range(rand_iter_num):
                    outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
                    extra.update(outputs)
                    extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')

        extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
                 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
                 'false_positive_mask': extra['false_positive_mask']}
        # bipartite matching-based loss
        self.criterion.losses = self.losses['seg'] # seg criterion losses

        if self.task_switch['mask']:
            losses = self.criterion(outputs, targets, extra)
        else:
            losses = self.criterion.forward_vlp(outputs, targets, extra)

        del outputs
        return losses

    def evaluate(self, batched_inputs):
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]

        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features, target_queries=queries_grounding)

        mask_cls_results = outputs["pred_logits"]
        mask_pred_results = outputs["pred_masks"]
        box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]

        # upsample masks
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
            mode="bilinear",
            align_corners=False,
        )

        input_size = mask_pred_results.shape[-2:]
        del outputs

        processed_results = []
        for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
            mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            processed_results.append({})

            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, image_size, height, width
                )
                mask_cls_result = mask_cls_result.to(mask_pred_result)

            # semantic segmentation inference
            if self.semantic_on:
                r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
                if not self.sem_seg_postprocess_before_inference:
                    r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
                processed_results[-1]["sem_seg"] = r

            # panoptic segmentation inference
            if self.panoptic_on:
                panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
                processed_results[-1]["panoptic_seg"] = panoptic_r

            # instance segmentation inference
            if self.instance_on:
                if self.task_switch['bbox']:
                    box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
                instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
                processed_results[-1]["instances"] = instance_r

        return processed_results
def evaluate_interactive(self, batched_inputs):
|
447 |
+
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None
        extra = {}

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]

        all_batch_shape_iou = []
        pred_smask_pointer = None
        prev_smask_pointer = None
        pred_smask_all = None

        # visualization code
        # v_pred_mask = []
        # v_pos_mask = []
        # v_neg_mask = []
        # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
        query_index = self.sem_seg_head.predictor.query_index
        if self.interactive_mode in ['best', 'best_random']:
            pos_masks = [x['spatial_query']['rand_shape'].to(self.device)[:,0] for x in batched_inputs]
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)

            neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False)[:,0] for x in batched_inputs]

            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
        elif self.interactive_mode == 'random':
            assert False, "interactive mode not correctly implemented"
            pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
            pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor

            neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
            neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
            extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
        else:
            assert False, "invalid interactive mode"

        for i in range(self.interactive_iter):
            # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
            outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
            extra.update(outputs)
            pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
            # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]

            s = image_sizes[0]
            b = batched_inputs[0]
            pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[0].sigmoid() > 0.5
            gt_smask = b['gt_masks_orisize']
            ious = get_iou(gt_smask, pred_smask_all)
            all_batch_shape_iou += [ious]
            if (ious > 0.9).sum() == len(ious):
                all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
                break
            if self.interactive_mode in ['best', 'best_random']:
                extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
            elif self.interactive_mode == 'random':
                extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
            else:
                assert False, "invalid interactive mode"
        all_batch_shape_iou = torch.stack(all_batch_shape_iou)
        processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]

        return processed_results

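    # Note: `get_iou` called in the loop above is imported from a shared utility
    # module elsewhere in this repo and is not defined in this file. A minimal
    # sketch of the per-mask IoU it is assumed to compute (one value per mask,
    # reduced over the spatial dimensions), shown for illustration only:
    #
    #     def get_iou(gt_masks, pred_masks, eps=1e-6):
    #         # gt_masks, pred_masks: bool tensors of shape (N, H, W)
    #         intersection = (gt_masks & pred_masks).flatten(1).sum(-1).float()
    #         union = (gt_masks | pred_masks).flatten(1).sum(-1).float()
    #         return intersection / (union + eps)
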
    def evaluate_interactive_single(self, batched_inputs, extra={}):
        assert self.task_switch['spatial']
        assert 'spatial_query' in batched_inputs[0]
        assert len(batched_inputs) == 1, "only support batch size equal to 1"

        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)
        img_bs = images.tensor.shape[0]

        targets = targets_grounding = queries_grounding = None

        features = self.backbone(images.tensor)
        mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)

        image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
        nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
        multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
        mask_features = mask_features.repeat(nm,1,1,1)

        outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
        pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')

        s = image_sizes[0]
        b = batched_inputs[0]
        pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
        pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
        ious = []
        if 'gt_masks_orisize' in b:
            gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
            ious = get_iou(gt_smask, pred_smask_ori)
        processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
        return processed_results

def evaluate_interactive_grounding(self, batched_inputs):
|
557 |
+
assert self.task_switch['spatial']
|
558 |
+
assert 'spatial_query' in batched_inputs[0]
|
559 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
560 |
+
|
561 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
562 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
563 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
564 |
+
img_bs = images.tensor.shape[0]
|
565 |
+
|
566 |
+
targets = targets_grounding = queries_grounding = None
|
567 |
+
extra = {}
|
568 |
+
|
569 |
+
features = self.backbone(images.tensor)
|
570 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
571 |
+
|
572 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
573 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
574 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
575 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
576 |
+
|
577 |
+
all_batch_shape_iou = []
|
578 |
+
pred_smask_pointer = None
|
579 |
+
prev_smask_pointer = None
|
580 |
+
pred_smask_all = None
|
581 |
+
|
582 |
+
# visualization code
|
583 |
+
# v_pred_mask = []
|
584 |
+
# v_pos_mask = []
|
585 |
+
# v_neg_mask = []
|
586 |
+
# v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
|
587 |
+
query_index = self.sem_seg_head.predictor.query_index
|
588 |
+
if self.interactive_mode in ['best', 'best_random']:
|
589 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
590 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
591 |
+
|
592 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
593 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
594 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
595 |
+
elif self.interactive_mode == 'random':
|
596 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
|
597 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
|
598 |
+
|
599 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
|
600 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
|
601 |
+
extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
|
602 |
+
else:
|
603 |
+
assert False, "invalid interactive mode"
|
604 |
+
|
605 |
+
grd_texts = batched_inputs[0]['classes']
|
606 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
607 |
+
token_emb = gtext['token_emb']
|
608 |
+
tokens = gtext['tokens']
|
609 |
+
query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
|
610 |
+
non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
|
611 |
+
|
612 |
+
extra['grounding_tokens'] = query_emb
|
613 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
614 |
+
|
615 |
+
for i in range(self.interactive_iter):
|
616 |
+
# v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
|
617 |
+
# v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
|
618 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
|
619 |
+
extra.update(outputs)
|
620 |
+
pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
|
621 |
+
# v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
|
622 |
+
|
623 |
+
s = image_sizes[0]
|
624 |
+
b = batched_inputs[0]
|
625 |
+
pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
|
626 |
+
gt_smask = b['gt_masks_orisize']
|
627 |
+
ious = get_iou(gt_smask, pred_smask_all)
|
628 |
+
all_batch_shape_iou += [ious]
|
629 |
+
if (ious > 0.9).sum() == len(ious):
|
630 |
+
all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
|
631 |
+
break
|
632 |
+
if self.interactive_mode in ['best', 'best_random']:
|
633 |
+
extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
|
634 |
+
elif self.interactive_mode == 'random':
|
635 |
+
extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
|
636 |
+
else:
|
637 |
+
assert False, "invalid interactive mode"
|
638 |
+
all_batch_shape_iou = torch.stack(all_batch_shape_iou)
|
639 |
+
processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
|
640 |
+
|
641 |
+
# visualization
|
642 |
+
# VL.step()
|
643 |
+
# import cv2
|
644 |
+
# v_masks = []
|
645 |
+
# v_pos_masks = []
|
646 |
+
# v_neg_masks = []
|
647 |
+
# txt = []
|
648 |
+
|
649 |
+
# img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
|
650 |
+
# mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
|
651 |
+
# acc_pos_mask = np.zeros(v_pos_mask[0].shape)
|
652 |
+
# acc_neg_mask = np.zeros(v_neg_mask[0].shape)
|
653 |
+
# for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
|
654 |
+
# # dilate x,y
|
655 |
+
# x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
|
656 |
+
# y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
|
657 |
+
# acc_pos_mask += x
|
658 |
+
# acc_neg_mask += y
|
659 |
+
|
660 |
+
# v_masks += [z]
|
661 |
+
# v_pos_masks += [acc_pos_mask.clip(0,1)]
|
662 |
+
# v_neg_masks += [acc_neg_mask.clip(0,1)]
|
663 |
+
# txt += ["pred_{}".format(str(iou[0].item())[0:5])]
|
664 |
+
|
665 |
+
# VL.add_image(img[:,:,::-1])
|
666 |
+
# VL.insert(mask_img, "gt_mask")
|
667 |
+
# VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
|
668 |
+
return processed_results
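    # The grounding-token packing above uses pad_sequence's default
    # (batch_first=False) layout: `query_emb` is (max_tokens, num_texts, C),
    # padded slots are filled with -1, and `query_emb.sum(dim=-1) < 0` is used
    # to flag those padded slots. A minimal shape check, for illustration only:
    #
    #     seqs = [torch.ones(3, 4), torch.ones(5, 4)]              # two texts, C=4
    #     padded = nn.utils.rnn.pad_sequence(seqs, padding_value=-1)
    #     assert padded.shape == (5, 2, 4)                         # (L, N, C)
    #     pad_mask = padded.sum(dim=-1) < 0                        # True on padding
    #     assert pad_mask[3:, 0].all() and not pad_mask[:, 1].any()
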
|
669 |
+
|
670 |
+
def evaluate_referring_image(self, batched_inputs, extra={}):
|
671 |
+
assert self.task_switch['spatial']
|
672 |
+
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
673 |
+
assert self.interactive_mode == 'best'
|
674 |
+
|
675 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
676 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
677 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
678 |
+
img_bs = images.tensor.shape[0]
|
679 |
+
|
680 |
+
targets = targets_grounding = queries_grounding = None
|
681 |
+
features = self.backbone(images.tensor)
|
682 |
+
mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
|
683 |
+
|
684 |
+
if 'spatial_query' in batched_inputs[0]:
|
685 |
+
image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
|
686 |
+
nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
|
687 |
+
multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
|
688 |
+
mask_features = mask_features.repeat(nm,1,1,1)
|
689 |
+
|
690 |
+
query_index = self.sem_seg_head.predictor.query_index
|
691 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
692 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
|
693 |
+
|
694 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
695 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
696 |
+
extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
|
697 |
+
|
698 |
+
outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
|
699 |
+
return outputs, images.tensor.shape
|
700 |
+
|
701 |
+
def evaluate_grounding(self, batched_inputs, mode):
|
702 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
703 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
704 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
705 |
+
assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
|
706 |
+
|
707 |
+
extra = {}
|
708 |
+
# mask_pred_results = []
|
709 |
+
# for idx, batch_per_image in enumerate(batched_inputs):
|
710 |
+
# grd_texts = batch_per_image['groundings']['texts']
|
711 |
+
# grd_masks = []
|
712 |
+
# for anno_text in grd_texts:
|
713 |
+
# gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
|
714 |
+
# token_emb = gtext['token_emb']
|
715 |
+
# tokens = gtext['tokens']
|
716 |
+
|
717 |
+
# grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
|
718 |
+
# extra['grounding_tokens'] = grd_emb[:,None]
|
719 |
+
|
720 |
+
# assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
|
721 |
+
# features = self.backbone(images.tensor)
|
722 |
+
# outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
723 |
+
|
724 |
+
# pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
|
725 |
+
# v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
|
726 |
+
# t_emb = grd_emb[-1:]
|
727 |
+
|
728 |
+
# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
729 |
+
# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
730 |
+
|
731 |
+
# temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
732 |
+
# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
733 |
+
|
734 |
+
# matched_id = out_prob.max(0)[1]
|
735 |
+
# grd_masks += [pred_gmasks[matched_id,:,:]]
|
736 |
+
# mask_pred_results += [torch.cat(grd_masks)]
|
737 |
+
|
738 |
+
# comment for multi object inference.
|
739 |
+
mask_pred_results = []
|
740 |
+
for idx, batch_per_image in enumerate(batched_inputs):
|
741 |
+
grd_texts = batch_per_image['groundings']['texts']
|
742 |
+
if self.standard_text_for_eval:
|
743 |
+
standard_texts = []
|
744 |
+
for grd in batch_per_image['grounding_info']:
|
745 |
+
mask_file = grd['mask_file'].split('.')[0].split('/')[-1]
|
746 |
+
target = mask_file.split('_')[-1].replace('+', ' ')
|
747 |
+
site = mask_file.split('_')[-2].replace('+', ' ')
|
748 |
+
modality = mask_file.split('_')[-3].replace('+', ' ')
|
749 |
+
standard_texts.append(f'{target} in {site} {modality}')
|
750 |
+
grd_texts = standard_texts
|
751 |
+
batch_per_image['groundings']['texts'] = standard_texts
|
752 |
+
|
753 |
+
|
754 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
755 |
+
token_emb = gtext['token_emb']
|
756 |
+
tokens = gtext['tokens']
|
757 |
+
query_emb = token_emb[tokens['attention_mask'].bool()]
|
758 |
+
non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
|
759 |
+
|
760 |
+
extra['grounding_tokens'] = query_emb[:,None]
|
761 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
762 |
+
|
763 |
+
features = self.backbone(images.tensor)
|
764 |
+
outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
765 |
+
|
766 |
+
pred_gmasks = outputs['pred_gmasks'][idx]
|
767 |
+
v_emb = outputs['pred_gtexts'][idx]
|
768 |
+
t_emb = gtext['class_emb']
|
769 |
+
|
770 |
+
t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
771 |
+
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
772 |
+
|
773 |
+
temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
774 |
+
out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
775 |
+
|
776 |
+
matched_id = out_prob.max(0)[1]
|
777 |
+
mask_pred_results += [pred_gmasks[matched_id,:,:]]
|
778 |
+
|
779 |
+
for i in range(len(mask_pred_results)):
|
780 |
+
# upsample masks
|
781 |
+
mask_pred_results[i] = F.interpolate(
|
782 |
+
mask_pred_results[i][None,],
|
783 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
784 |
+
mode="bilinear",
|
785 |
+
align_corners=False,
|
786 |
+
)[0]
|
787 |
+
|
788 |
+
processed_results = []
|
789 |
+
for mask_pred_result, input_per_image, image_size in zip(
|
790 |
+
mask_pred_results, batched_inputs, images.image_sizes
|
791 |
+
):
|
792 |
+
height = input_per_image.get("height", image_size[0])
|
793 |
+
width = input_per_image.get("width", image_size[1])
|
794 |
+
processed_results.append({})
|
795 |
+
|
796 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
797 |
+
mask_pred_result, image_size, height, width
|
798 |
+
)
|
799 |
+
processed_results[-1]['grounding_mask'] = mask_pred_result
|
800 |
+
|
801 |
+
# compute bbox
|
802 |
+
# bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
|
803 |
+
# bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
804 |
+
# processed_results[-1]['grounding_box'] = bbox
|
805 |
+
|
806 |
+
return processed_results
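    # The matching step above picks, for each text prompt, the decoder query
    # whose embedding is closest to the prompt's class embedding. `vl_similarity`
    # comes from `..language.loss`; conceptually (a sketch, not its exact
    # implementation) the selection reduces to:
    #
    #     v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)    # (Q, C) query embeddings
    #     t = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)    # (T, C) text embeddings
    #     logits = temperature.exp() * (v @ t.t())                 # (Q, T) scaled cosine similarity
    #     matched_id = logits.max(0)[1]                            # best query per text
    #     masks = pred_gmasks[matched_id]                          # (T, H, W) selected masks
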
|
807 |
+
|
808 |
+
def evaluate_grounding_sptial(self, batched_inputs, mode):
|
809 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
810 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
811 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
812 |
+
assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
|
813 |
+
|
814 |
+
extra = {}
|
815 |
+
dilation = 3
|
816 |
+
pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
|
817 |
+
pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
|
818 |
+
pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
|
819 |
+
|
820 |
+
neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
|
821 |
+
neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
|
822 |
+
|
823 |
+
mask_pred_results = []
|
824 |
+
for idx, batch_per_image in enumerate(batched_inputs):
|
825 |
+
grd_texts = batch_per_image['groundings']['texts']
|
826 |
+
grd_masks = []
|
827 |
+
for idx2, anno_text in enumerate(grd_texts):
|
828 |
+
extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
|
829 |
+
|
830 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
|
831 |
+
token_emb = gtext['token_emb']
|
832 |
+
tokens = gtext['tokens']
|
833 |
+
|
834 |
+
grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
|
835 |
+
non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
|
836 |
+
extra['grounding_tokens'] = grd_emb[:,None]
|
837 |
+
extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
838 |
+
|
839 |
+
assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
|
840 |
+
features = self.backbone(images.tensor)
|
841 |
+
outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
842 |
+
|
843 |
+
pred_gmasks = outputs['pred_gmasks'][idx]
|
844 |
+
v_emb = outputs['pred_gtexts'][idx]
|
845 |
+
t_emb = gtext['class_emb']
|
846 |
+
|
847 |
+
t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
848 |
+
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
849 |
+
|
850 |
+
temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
851 |
+
out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
852 |
+
|
853 |
+
matched_id = out_prob.max(0)[1]
|
854 |
+
grd_masks += [pred_gmasks[matched_id,:,:]]
|
855 |
+
# grd_masks += [outputs['prev_mask'][0]]
|
856 |
+
|
857 |
+
mask_pred_results += [torch.cat(grd_masks)]
|
858 |
+
|
859 |
+
# comment for multi object inference.
|
860 |
+
# mask_pred_results = []
|
861 |
+
# for idx, batch_per_image in enumerate(batched_inputs):
|
862 |
+
# grd_texts = batch_per_image['groundings']['texts']
|
863 |
+
# grd_texts = [x[0] for x in grd_texts]
|
864 |
+
|
865 |
+
# gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
866 |
+
# token_emb = gtext['token_emb']
|
867 |
+
# tokens = gtext['tokens']
|
868 |
+
# query_emb = token_emb[tokens['attention_mask'].bool()]
|
869 |
+
# non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
|
870 |
+
|
871 |
+
# extra['grounding_tokens'] = query_emb[:,None]
|
872 |
+
# extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
|
873 |
+
|
874 |
+
# features = self.backbone(images.tensor)
|
875 |
+
# outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
876 |
+
|
877 |
+
# pred_gmasks = outputs['pred_gmasks'][idx]
|
878 |
+
# v_emb = outputs['pred_gtexts'][idx]
|
879 |
+
# t_emb = gtext['class_emb']
|
880 |
+
|
881 |
+
# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
882 |
+
# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
883 |
+
|
884 |
+
# temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
885 |
+
# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
886 |
+
|
887 |
+
# matched_id = out_prob.max(0)[1]
|
888 |
+
# mask_pred_results += [pred_gmasks[matched_id,:,:]]
|
889 |
+
|
890 |
+
for i in range(len(mask_pred_results)):
|
891 |
+
# upsample masks
|
892 |
+
mask_pred_results[i] = F.interpolate(
|
893 |
+
mask_pred_results[i][None,],
|
894 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
895 |
+
mode="bilinear",
|
896 |
+
align_corners=False,
|
897 |
+
)[0]
|
898 |
+
|
899 |
+
processed_results = []
|
900 |
+
for mask_pred_result, input_per_image, image_size in zip(
|
901 |
+
mask_pred_results, batched_inputs, images.image_sizes
|
902 |
+
):
|
903 |
+
height = input_per_image.get("height", image_size[0])
|
904 |
+
width = input_per_image.get("width", image_size[1])
|
905 |
+
processed_results.append({})
|
906 |
+
|
907 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
908 |
+
mask_pred_result, image_size, height, width
|
909 |
+
)
|
910 |
+
processed_results[-1]['grounding_mask'] = mask_pred_result
|
911 |
+
|
912 |
+
return processed_results
|
913 |
+
|
914 |
+
def prepare_targets(self, batched_inputs, images):
|
915 |
+
h_pad, w_pad = images.tensor.shape[-2:]
|
916 |
+
new_targets = []
|
917 |
+
for idx, batch_per_image in enumerate(batched_inputs):
|
918 |
+
target_dict = {}
|
919 |
+
if self.task_switch['mask']:
|
920 |
+
targets_per_image = batch_per_image['instances'].to(self.device)
|
921 |
+
# pad gt
|
922 |
+
gt_masks = targets_per_image.gt_masks.tensor
|
923 |
+
padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
|
924 |
+
padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
|
925 |
+
|
926 |
+
gt_boxes = targets_per_image.gt_boxes.tensor
|
927 |
+
ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
|
928 |
+
gt_boxes = gt_boxes / ratio
|
929 |
+
xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
|
930 |
+
gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
|
931 |
+
|
932 |
+
target_dict.update({
|
933 |
+
"labels": targets_per_image.gt_classes,
|
934 |
+
"is_things": targets_per_image.is_things,
|
935 |
+
"masks": padded_masks,
|
936 |
+
"boxes": gt_boxes,
|
937 |
+
})
|
938 |
+
|
939 |
+
if self.task_switch['spatial']:
|
940 |
+
# prepare targets for spatial query
|
941 |
+
target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
|
942 |
+
|
943 |
+
if self.task_switch['grounding']:
|
944 |
+
grd_masks = batch_per_image['groundings']['masks']
|
945 |
+
grd_texts = batch_per_image['groundings']['texts']
|
946 |
+
grd_hash = batch_per_image['groundings']['hash']
|
947 |
+
grd_task = batch_per_image['groundings']['mode']
|
948 |
+
|
949 |
+
if len(grd_masks) == 0:
|
950 |
+
padded_masks = None
|
951 |
+
else:
|
952 |
+
padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
|
953 |
+
padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
|
954 |
+
|
955 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
956 |
+
token_emb = gtext['token_emb']
|
957 |
+
tokens = gtext['tokens']
|
958 |
+
|
959 |
+
unique_hash_id = np.unique(grd_hash, return_index=True)[1]
|
960 |
+
selected_mask = np.zeros(len(grd_hash)).astype(bool)
|
961 |
+
selected_mask[unique_hash_id] = True
|
962 |
+
|
963 |
+
selected_token_emb = token_emb[selected_mask]
|
964 |
+
selected_attn_mask = tokens['attention_mask'][selected_mask]
|
965 |
+
query_emb = selected_token_emb[selected_attn_mask.bool()]
|
966 |
+
|
967 |
+
class_idx = tokens['attention_mask'].sum(dim=-1) - 1
|
968 |
+
class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
|
969 |
+
class_emb = token_emb[class_idx]
|
970 |
+
|
971 |
+
target_dict['grounding_masks'] = padded_masks
|
972 |
+
target_dict['grounding_query_embs'] = query_emb
|
973 |
+
target_dict['grounding_class_embs'] = class_emb
|
974 |
+
target_dict['grounding_hash'] = grd_hash
|
975 |
+
target_dict['grounding_task'] = grd_task
|
976 |
+
|
977 |
+
new_targets.append(target_dict)
|
978 |
+
return new_targets
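    # Box targets above are converted from absolute XYXY pixel coordinates to
    # CXCYWH normalized by the padded canvas size. A small worked example,
    # for illustration only:
    #
    #     box_xyxy = torch.tensor([[100., 50., 300., 250.]])       # on a 400x400 canvas
    #     box = box_xyxy / torch.tensor([400., 400., 400., 400.])  # -> [0.25, 0.125, 0.75, 0.625]
    #     xc, yc = (box[:, 0] + box[:, 2]) / 2, (box[:, 1] + box[:, 3]) / 2
    #     w, h = box[:, 2] - box[:, 0], box[:, 3] - box[:, 1]
    #     # stacked cxcywh -> [[0.5, 0.375, 0.5, 0.5]]
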
|
979 |
+
|
980 |
+
def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
|
981 |
+
gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
|
982 |
+
gt_masks = Spatial_ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
|
983 |
+
|
984 |
+
pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
|
985 |
+
prev_masks = nn.utils.rnn.pad_sequence(outputs['spatial_query_pos_mask'], padding_value=False, batch_first=True) | \
|
986 |
+
nn.utils.rnn.pad_sequence(outputs['spatial_query_neg_mask'], padding_value=False, batch_first=True)
|
987 |
+
|
988 |
+
fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
|
989 |
+
fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
|
990 |
+
|
991 |
+
# compute iou between gt and pred
|
992 |
+
iou = (gt_masks & pred_masks).sum(list(range(2,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(2,len(fn.shape)))) + 1e-8)
|
993 |
+
fn_sum = fn.sum(dim=list(range(2,len(fn.shape))))
|
994 |
+
fp_sum = fp.sum(dim=list(range(2,len(fp.shape))))
|
995 |
+
|
996 |
+
is_postive = fn_sum > fp_sum
|
997 |
+
select_mask = torch.zeros_like(fn)
|
998 |
+
select_mask[is_postive] = fn[is_postive]
|
999 |
+
select_mask[~is_postive] = fp[~is_postive]
|
1000 |
+
# is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
|
1001 |
+
|
1002 |
+
# conv implementation
|
1003 |
+
bs,ns,h,w = select_mask.shape
|
1004 |
+
mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(bs*ns,-1)
|
1005 |
+
if mode == 'best':
|
1006 |
+
max_xy_idx = torch.stack([torch.arange(bs*ns), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
1007 |
+
elif mode == 'best_random':
|
1008 |
+
max_xy_idx = torch.stack([torch.arange(bs*ns), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
|
1009 |
+
next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
|
1010 |
+
next_mask = next_mask.view(bs*ns,-1)
|
1011 |
+
next_mask[max_xy_idx] = True
|
1012 |
+
next_mask = next_mask.reshape((bs*ns,1,h,w)).float()
|
1013 |
+
dilation = 3
|
1014 |
+
next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2).reshape(bs,ns,h,w) > 0
|
1015 |
+
|
1016 |
+
# determine whether next mask is zero
|
1017 |
+
keep = (iou < 0.925)
|
1018 |
+
next_mask = next_mask & keep.view(bs,ns,1,1)
|
1019 |
+
|
1020 |
+
pos_mask = []
|
1021 |
+
neg_mask = []
|
1022 |
+
for idx, ip in enumerate(is_postive):
|
1023 |
+
mask_len = len(outputs['spatial_query_pos_mask'][idx])
|
1024 |
+
pos_mask += [outputs['spatial_query_pos_mask'][idx] | (next_mask[idx][:mask_len] & ip[:mask_len,None,None])]
|
1025 |
+
neg_mask += [outputs['spatial_query_neg_mask'][idx] | (next_mask[idx][:mask_len] & (~ip[:mask_len,None,None]))]
|
1026 |
+
|
1027 |
+
if 'false_positive_mask' in outputs:
|
1028 |
+
fp = outputs['false_positive_mask'] | fp
|
1029 |
+
return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
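    # The routine above simulates the next user click: it picks the dominant
    # error type (false negatives vs. false positives), finds the point deepest
    # inside that error region with a distance transform, and dilates it into a
    # small disc with `self.dilation_kernel`. A condensed sketch of the sampling
    # idea, assuming `distance_transform(x)` returns the per-pixel distance to
    # the nearest non-zero pixel of `x` (illustration only):
    #
    #     err = fn if fn_sum > fp_sum else fp                      # (1, 1, H, W) bool error region
    #     dt = distance_transform((~err).float())                  # large deep inside the region
    #     click = dt.flatten().argmax()                            # deepest interior pixel
    #     point = torch.zeros_like(err).flatten()
    #     point[click] = True
    #     point = point.reshape(err.shape).float()
    #     next_click = F.conv2d(point, dilation_kernel, padding=1) > 0
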
|
1030 |
+
|
1031 |
+
    def semantic_inference(self, mask_cls, mask_pred):
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

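    # The einsum above combines the Q query masks into K per-class maps: each
    # class channel is the sum of the query mask probabilities weighted by that
    # query's class score. An equivalent formulation, for illustration only:
    #
    #     # mask_cls: (Q, K) class scores, mask_pred: (Q, H, W) mask probabilities
    #     semseg = (mask_cls.t() @ mask_pred.flatten(1)).view(mask_cls.shape[1], *mask_pred.shape[-2:])
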
def panoptic_inference(self, mask_cls, mask_pred):
|
1038 |
+
scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
|
1039 |
+
mask_pred = mask_pred.sigmoid()
|
1040 |
+
|
1041 |
+
keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
|
1042 |
+
cur_scores = scores[keep]
|
1043 |
+
cur_classes = labels[keep]
|
1044 |
+
cur_masks = mask_pred[keep]
|
1045 |
+
cur_mask_cls = mask_cls[keep]
|
1046 |
+
cur_mask_cls = cur_mask_cls[:, :-1]
|
1047 |
+
|
1048 |
+
cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
|
1049 |
+
|
1050 |
+
h, w = cur_masks.shape[-2:]
|
1051 |
+
panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
|
1052 |
+
segments_info = []
|
1053 |
+
|
1054 |
+
current_segment_id = 0
|
1055 |
+
|
1056 |
+
if cur_masks.shape[0] == 0:
|
1057 |
+
# We didn't detect any mask :(
|
1058 |
+
return panoptic_seg, segments_info
|
1059 |
+
else:
|
1060 |
+
# take argmax
|
1061 |
+
cur_mask_ids = cur_prob_masks.argmax(0)
|
1062 |
+
stuff_memory_list = {}
|
1063 |
+
for k in range(cur_classes.shape[0]):
|
1064 |
+
pred_class = cur_classes[k].item()
|
1065 |
+
isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
|
1066 |
+
mask_area = (cur_mask_ids == k).sum().item()
|
1067 |
+
original_area = (cur_masks[k] >= 0.5).sum().item()
|
1068 |
+
mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
|
1069 |
+
|
1070 |
+
if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
|
1071 |
+
if mask_area / original_area < self.overlap_threshold:
|
1072 |
+
continue
|
1073 |
+
|
1074 |
+
# merge stuff regions
|
1075 |
+
if not isthing:
|
1076 |
+
if int(pred_class) in stuff_memory_list.keys():
|
1077 |
+
panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
|
1078 |
+
continue
|
1079 |
+
else:
|
1080 |
+
stuff_memory_list[int(pred_class)] = current_segment_id + 1
|
1081 |
+
|
1082 |
+
current_segment_id += 1
|
1083 |
+
panoptic_seg[mask] = current_segment_id
|
1084 |
+
|
1085 |
+
segments_info.append(
|
1086 |
+
{
|
1087 |
+
"id": current_segment_id,
|
1088 |
+
"isthing": bool(isthing),
|
1089 |
+
"category_id": int(pred_class),
|
1090 |
+
}
|
1091 |
+
)
|
1092 |
+
|
1093 |
+
return panoptic_seg, segments_info
|
1094 |
+
|
1095 |
+
def instance_inference(self, mask_cls, mask_pred, box_pred):
|
1096 |
+
# mask_pred is already processed to have the same shape as original input
|
1097 |
+
image_size = mask_pred.shape[-2:]
|
1098 |
+
|
1099 |
+
# [Q, K]
|
1100 |
+
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
|
1101 |
+
labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
|
1102 |
+
# scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
|
1103 |
+
scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
|
1104 |
+
|
1105 |
+
labels_per_image = labels[topk_indices]
|
1106 |
+
topk_indices = (topk_indices // self.sem_seg_head.num_classes)
|
1107 |
+
# mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
|
1108 |
+
mask_pred = mask_pred[topk_indices]
|
1109 |
+
if box_pred is not None:
|
1110 |
+
box_pred = box_pred[topk_indices]
|
1111 |
+
|
1112 |
+
# if this is panoptic segmentation, we only keep the "thing" classes
|
1113 |
+
if self.panoptic_on:
|
1114 |
+
keep = torch.zeros_like(scores_per_image).bool()
|
1115 |
+
for i, lab in enumerate(labels_per_image):
|
1116 |
+
keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
|
1117 |
+
|
1118 |
+
scores_per_image = scores_per_image[keep]
|
1119 |
+
labels_per_image = labels_per_image[keep]
|
1120 |
+
mask_pred = mask_pred[keep]
|
1121 |
+
|
1122 |
+
if box_pred is not None:
|
1123 |
+
box_pred = box_pred[keep]
|
1124 |
+
|
1125 |
+
result = Instances(image_size)
|
1126 |
+
# mask (before sigmoid)
|
1127 |
+
result.pred_masks = (mask_pred > 0).float()
|
1128 |
+
# result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
|
1129 |
+
# Uncomment the following to get boxes from masks (this is slow)
|
1130 |
+
|
1131 |
+
if box_pred is not None:
|
1132 |
+
result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
|
1133 |
+
else:
|
1134 |
+
result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
|
1135 |
+
|
1136 |
+
# calculate average mask prob
|
1137 |
+
mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
|
1138 |
+
result.scores = scores_per_image * mask_scores_per_image
|
1139 |
+
result.pred_classes = labels_per_image
|
1140 |
+
|
1141 |
+
return result
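    # The top-k selection above flattens the (Q, K) score matrix, so a flat
    # index encodes (query, class) as flat = q * K + c: `labels[flat]` recovers
    # the class and integer division by K recovers the query. A small worked
    # example, for illustration only:
    #
    #     # Q = 2 queries, K = 3 classes: flat index 4 -> query 4 // 3 = 1, class 4 % 3 = 1
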
|
1142 |
+
|
1143 |
+
def prepare_targets4query(self, targets, images, topk=5):
|
1144 |
+
h_pad, w_pad = images.tensor.shape[-2:]
|
1145 |
+
new_targets = []
|
1146 |
+
new_queries = []
|
1147 |
+
for targets_per_image in targets:
|
1148 |
+
# we randomly sample maximally topk concepts
|
1149 |
+
unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
|
1150 |
+
selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
|
1151 |
+
new_targets_per_image = []
|
1152 |
+
new_queries_per_image = []
|
1153 |
+
for clss in selected_target_classes:
|
1154 |
+
indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
|
1155 |
+
# pad gt
|
1156 |
+
gt_masks = targets_per_image.gt_masks[indices]
|
1157 |
+
padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
|
1158 |
+
padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
|
1159 |
+
|
1160 |
+
# convert class into concept name and then token seq
|
1161 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([BIOMED_CLASSES[clss]], name='grounding')
|
1162 |
+
query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
|
1163 |
+
|
1164 |
+
new_targets.append(
|
1165 |
+
{
|
1166 |
+
"labels": targets_per_image.gt_classes[indices],
|
1167 |
+
"masks": padded_masks,
|
1168 |
+
}
|
1169 |
+
)
|
1170 |
+
new_queries_per_image.append(query)
|
1171 |
+
new_queries.append(new_queries_per_image)
|
1172 |
+
|
1173 |
+
return new_targets, new_queries


@register_model
def get_seem_model(cfg, **kwargs):
    return GeneralizedSEEM(cfg)

modeling/architectures/xdecoder_model.py
ADDED
@@ -0,0 +1,937 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected]), Ziyi Dou, Jianwei Yang
# --------------------------------------------------------

from typing import Tuple
import random

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

from timm.models.layers import trunc_normal_
from nltk.stem.lancaster import LancasterStemmer
from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.data import MetadataCatalog

from .build import register_model
from ..utils import configurable, get_class_names
from ..vision.backbone import build_backbone, Backbone
from ..body import build_xdecoder_head
from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
from ..language import build_language_encoder
from ..language.loss import vl_similarity, image_text_contrastive_loss_queue
from utilities.prompt_engineering import prompt_engineering
from utilities.constants import COCO_PANOPTIC_CLASSES

st = LancasterStemmer()


class GeneralizedXdecoder(nn.Module):
|
36 |
+
|
37 |
+
@configurable
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
*,
|
41 |
+
backbone: Backbone,
|
42 |
+
sem_seg_head: nn.Module,
|
43 |
+
criterion: nn.Module,
|
44 |
+
losses: dict,
|
45 |
+
num_queries: int,
|
46 |
+
object_mask_threshold: float,
|
47 |
+
overlap_threshold: float,
|
48 |
+
metadata,
|
49 |
+
task_switch: dict,
|
50 |
+
phrase_prob: float,
|
51 |
+
size_divisibility: int,
|
52 |
+
sem_seg_postprocess_before_inference: bool,
|
53 |
+
pixel_mean: Tuple[float],
|
54 |
+
pixel_std: Tuple[float],
|
55 |
+
# inference
|
56 |
+
semantic_on: bool,
|
57 |
+
panoptic_on: bool,
|
58 |
+
instance_on: bool,
|
59 |
+
test_topk_per_image: int,
|
60 |
+
train_dataset_name: str,
|
61 |
+
retrieval_emsemble: bool,
|
62 |
+
backbone_dim: int,
|
63 |
+
dim_proj: int,
|
64 |
+
):
|
65 |
+
"""
|
66 |
+
Args:
|
67 |
+
backbone: a backbone module, must follow detectron2's backbone interface
|
68 |
+
sem_seg_head: a module that predicts semantic segmentation from backbone features
|
69 |
+
criterion: a module that defines the loss
|
70 |
+
num_queries: int, number of queries
|
71 |
+
object_mask_threshold: float, threshold to filter query based on classification score
|
72 |
+
for panoptic segmentation inference
|
73 |
+
overlap_threshold: overlap threshold used in general inference for panoptic segmentation
|
74 |
+
metadata: dataset meta, get `thing` and `stuff` category names for panoptic
|
75 |
+
segmentation inference
|
76 |
+
size_divisibility: Some backbones require the input height and width to be divisible by a
|
77 |
+
specific integer. We can use this to override such requirement.
|
78 |
+
sem_seg_postprocess_before_inference: whether to resize the prediction back
|
79 |
+
to original input size before semantic segmentation inference or after.
|
80 |
+
For high-resolution dataset like Mapillary, resizing predictions before
|
81 |
+
inference will cause OOM error.
|
82 |
+
pixel_mean, pixel_std: list or tuple with #channels element, representing
|
83 |
+
the per-channel mean and std to be used to normalize the input image
|
84 |
+
semantic_on: bool, whether to output semantic segmentation prediction
|
85 |
+
instance_on: bool, whether to output instance segmentation prediction
|
86 |
+
panoptic_on: bool, whether to output panoptic segmentation prediction
|
87 |
+
test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
|
88 |
+
"""
|
89 |
+
super().__init__()
|
90 |
+
self.backbone = backbone
|
91 |
+
self.sem_seg_head = sem_seg_head
|
92 |
+
self.criterion = criterion
|
93 |
+
self.losses = losses
|
94 |
+
self.num_queries = num_queries
|
95 |
+
self.overlap_threshold = overlap_threshold
|
96 |
+
self.object_mask_threshold = object_mask_threshold
|
97 |
+
self.metadata = metadata
|
98 |
+
if size_divisibility < 0:
|
99 |
+
# use backbone size_divisibility if not set
|
100 |
+
size_divisibility = self.backbone.size_divisibility
|
101 |
+
self.size_divisibility = size_divisibility
|
102 |
+
self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
|
103 |
+
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
104 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
105 |
+
|
106 |
+
# additional args
|
107 |
+
self.semantic_on = semantic_on
|
108 |
+
self.instance_on = instance_on
|
109 |
+
self.panoptic_on = panoptic_on
|
110 |
+
|
111 |
+
# caption argument
|
112 |
+
self.task_switch = task_switch
|
113 |
+
self.phrase_prob = phrase_prob
|
114 |
+
|
115 |
+
self.test_topk_per_image = test_topk_per_image
|
116 |
+
self.train_class_names = get_class_names(train_dataset_name)
|
117 |
+
|
118 |
+
self.retrieval_emsemble = retrieval_emsemble
|
119 |
+
# backbone itc loss
|
120 |
+
if task_switch['retrieval'] and retrieval_emsemble:
|
121 |
+
self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
|
122 |
+
trunc_normal_(self.backbone_proj, std=.02)
|
123 |
+
|
124 |
+
if not self.semantic_on:
|
125 |
+
assert self.sem_seg_postprocess_before_inference
|
126 |
+
|
127 |
+
@classmethod
|
128 |
+
def from_config(cls, cfg):
|
129 |
+
enc_cfg = cfg['MODEL']['ENCODER']
|
130 |
+
dec_cfg = cfg['MODEL']['DECODER']
|
131 |
+
|
132 |
+
# Loss parameters:
|
133 |
+
deep_supervision = dec_cfg['DEEP_SUPERVISION']
|
134 |
+
no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
|
135 |
+
|
136 |
+
# loss weights, switcher for task, and top layers to compute loss
|
137 |
+
loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
|
138 |
+
'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
|
139 |
+
'caption': dec_cfg['CAPTION_WEIGHT'],
|
140 |
+
'captioning': dec_cfg['CAPTIONING_WEIGHT'],
|
141 |
+
'retrieval': {'decoder': dec_cfg['RETRIEVAL_WEIGHT'], 'backbone': dec_cfg['BACKBONER_WEIGHT']},
|
142 |
+
'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']}}
|
143 |
+
|
144 |
+
task_switch = {'bbox': dec_cfg.get('DETECTION', False),
|
145 |
+
'mask': dec_cfg.get('MASK', True),
|
146 |
+
'caption': dec_cfg['CAPTION'].get('ENABLED', False),
|
147 |
+
'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
|
148 |
+
'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
|
149 |
+
'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}
|
150 |
+
|
151 |
+
top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
|
152 |
+
'caption': dec_cfg.get('TOP_CAPTION_LAYERS', 10),
|
153 |
+
'captioning': dec_cfg.get('TOP_CAPTIONING_LAYERS', 10),
|
154 |
+
'retrieval': dec_cfg.get('TOP_RETRIEVAL_LAYERS', 10),
|
155 |
+
'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),}
|
156 |
+
|
157 |
+
# build model
|
158 |
+
extra = {'task_switch': task_switch}
|
159 |
+
backbone = build_backbone(cfg)
|
160 |
+
lang_encoder = build_language_encoder(cfg)
|
161 |
+
sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)
|
162 |
+
|
163 |
+
# building criterion
|
164 |
+
matcher = HungarianMatcher(
|
165 |
+
cost_class=loss_weights['mask']['ce'],
|
166 |
+
cost_mask=loss_weights['mask']['bce'],
|
167 |
+
cost_dice=loss_weights['mask']['dice'],
|
168 |
+
num_points=dec_cfg['TRAIN_NUM_POINTS'],
|
169 |
+
)
|
170 |
+
|
171 |
+
# init weight dict and criterion loss functions.
|
172 |
+
losses = {'seg': [], 'vlp': []}
|
173 |
+
if task_switch['mask']:
|
174 |
+
losses['seg'] += ["labels", "masks"]
|
175 |
+
if task_switch['caption']:
|
176 |
+
losses['seg'] += ["captions"]
|
177 |
+
if task_switch['grounding']:
|
178 |
+
losses['seg'] += ["groundings"]
|
179 |
+
if task_switch['captioning']:
|
180 |
+
losses['vlp'] += ["captionings"]
|
181 |
+
if task_switch['retrieval']:
|
182 |
+
losses['vlp'] += ["retrievals"]
|
183 |
+
|
184 |
+
weight_dict = {}
|
185 |
+
for key, turn_on in task_switch.items():
|
186 |
+
if turn_on:
|
187 |
+
if isinstance(loss_weights[key], dict):
|
188 |
+
# HACK it should support bbox in the future
|
189 |
+
for key_, weight in loss_weights[key].items():
|
190 |
+
weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
|
191 |
+
else:
|
192 |
+
weight_dict["loss_{}_0".format(key)] = loss_weights[key]
|
193 |
+
|
194 |
+
# generate full weight dict and remove not computed layers.
|
195 |
+
if deep_supervision:
|
196 |
+
dec_layers = dec_cfg['DEC_LAYERS']
|
197 |
+
aux_weight_dict = {}
|
198 |
+
for i in range(dec_layers - 1):
|
199 |
+
for k, v in weight_dict.items():
|
200 |
+
if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
|
201 |
+
continue
|
202 |
+
aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
|
203 |
+
weight_dict.update(aux_weight_dict)
|
204 |
+
|
205 |
+
grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
|
206 |
+
# generate critenrion for loss function.
|
207 |
+
criterion = SetCriterion(
|
208 |
+
sem_seg_head.num_classes,
|
209 |
+
matcher=matcher,
|
210 |
+
weight_dict=weight_dict,
|
211 |
+
top_x_layers=top_x_layers,
|
212 |
+
eos_coef=no_object_weight,
|
213 |
+
losses=[],
|
214 |
+
num_points=dec_cfg['TRAIN_NUM_POINTS'],
|
215 |
+
oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
|
216 |
+
importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
|
217 |
+
grounding_weight=grd_weight,
|
218 |
+
)
|
219 |
+
|
220 |
+
# extra logistic
|
221 |
+
train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
|
222 |
+
phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
|
223 |
+
|
224 |
+
return {
|
225 |
+
"backbone": backbone,
|
226 |
+
"sem_seg_head": sem_seg_head,
|
227 |
+
"criterion": criterion,
|
228 |
+
"losses": losses,
|
229 |
+
"num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
|
230 |
+
"object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
|
231 |
+
"overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
|
232 |
+
"metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
|
233 |
+
"size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
|
234 |
+
"sem_seg_postprocess_before_inference": (
|
235 |
+
dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
|
236 |
+
or dec_cfg['TEST']['PANOPTIC_ON']
|
237 |
+
or dec_cfg['TEST']['INSTANCE_ON']
|
238 |
+
),
|
239 |
+
"pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
|
240 |
+
"pixel_std": cfg['INPUT']['PIXEL_STD'],
|
241 |
+
"task_switch": task_switch,
|
242 |
+
"phrase_prob": phrase_prob,
|
243 |
+
# inference
|
244 |
+
"semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
|
245 |
+
"instance_on": dec_cfg['TEST']['INSTANCE_ON'],
|
246 |
+
"panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
|
247 |
+
"test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
|
248 |
+
"train_dataset_name": train_dataset_name,
|
249 |
+
"retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
|
250 |
+
"backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
|
251 |
+
"dim_proj": cfg['MODEL']['DIM_PROJ'],
|
252 |
+
}
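    # `from_config` maps the raw nested config onto the keyword arguments of
    # `__init__`; combined with the `@configurable` decorator provided by
    # `..utils` (a detectron2-style convention), the registry can build the
    # model straight from a config dict. Roughly, a sketch of that convention:
    #
    #     kwargs = GeneralizedXdecoder.from_config(cfg)
    #     model = GeneralizedXdecoder(**kwargs)      # what GeneralizedXdecoder(cfg) expands to
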
|
253 |
+
|
254 |
+
@property
|
255 |
+
def device(self):
|
256 |
+
return self.pixel_mean.device
|
257 |
+
|
258 |
+
def forward(self, batched_inputs, mode=None):
|
259 |
+
"""
|
260 |
+
Args:
|
261 |
+
batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
|
262 |
+
Each item in the list contains the inputs for one image.
|
263 |
+
For now, each item in the list is a dict that contains:
|
264 |
+
* "image": Tensor, image in (C, H, W) format.
|
265 |
+
* "instances": per-region ground truth
|
266 |
+
* Other information that's included in the original dicts, such as:
|
267 |
+
"height", "width" (int): the output resolution of the model (may be different
|
268 |
+
from input resolution), used in inference.
|
269 |
+
Returns:
|
270 |
+
list[dict]:
|
271 |
+
each dict has the results for one image. The dict contains the following keys:
|
272 |
+
|
273 |
+
* "sem_seg":
|
274 |
+
A Tensor that represents the
|
275 |
+
per-pixel segmentation prediced by the head.
|
276 |
+
The prediction has shape KxHxW that represents the logits of
|
277 |
+
each class for each pixel.
|
278 |
+
* "panoptic_seg":
|
279 |
+
A tuple that represent panoptic output
|
280 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
|
281 |
+
segments_info (list[dict]): Describe each segment in `panoptic_seg`.
|
282 |
+
Each dict contains keys "id", "category_id", "isthing".
|
283 |
+
"""
|
284 |
+
if self.training:
|
285 |
+
losses = {}
|
286 |
+
if self.task_switch['mask']:
|
287 |
+
losses_seg = self.forward_seg(batched_inputs['coco'])
|
288 |
+
losses.update(losses_seg)
|
289 |
+
if self.task_switch['retrieval'] or self.task_switch['captioning']:
|
290 |
+
losses_vlp = self.forward_vlp(batched_inputs['vlp'])
|
291 |
+
losses.update(losses_vlp)
|
292 |
+
for k in list(losses.keys()):
|
293 |
+
if k in self.criterion.weight_dict:
|
294 |
+
losses[k] *= self.criterion.weight_dict[k]
|
295 |
+
else: # remove this loss if not specified in `weight_dict`
|
296 |
+
losses.pop(k)
|
297 |
+
return losses
|
298 |
+
else:
|
299 |
+
if mode == 'retrieval':
|
300 |
+
return self.evaluate_retrieval(batched_inputs)
|
301 |
+
elif mode == 'captioning':
|
302 |
+
return self.evaluate_captioning(batched_inputs)
|
303 |
+
elif mode == 'classification':
|
304 |
+
return self.evaluate_classification(batched_inputs)
|
305 |
+
elif mode == 'grounding_refcoco':
|
306 |
+
return self.evaluate_grounding(batched_inputs, mode)
|
307 |
+
else:
|
308 |
+
return self.evaluate(batched_inputs)
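    # At inference time the model is therefore called with a list of per-image
    # dicts and an optional task `mode` ('retrieval', 'captioning',
    # 'classification', 'grounding_refcoco', or None for the default
    # segmentation evaluation). A minimal sketch of an evaluation call, with
    # assumed tensor names (illustration only):
    #
    #     model.eval()
    #     with torch.no_grad():
    #         batched_inputs = [{"image": image_chw, "height": H, "width": W}]
    #         results = model(batched_inputs)          # default segmentation evaluation
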
|
309 |
+
|
310 |
+
|
311 |
+
def forward_seg(self, batched_inputs):
|
312 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
313 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
314 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
315 |
+
|
316 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
|
317 |
+
|
318 |
+
extra = {}
|
319 |
+
# mask classification target
|
320 |
+
if "instances" in batched_inputs[0]:
|
321 |
+
# input bounding box is checked to be correct.
|
322 |
+
targets = self.prepare_targets(batched_inputs, images)
|
323 |
+
|
324 |
+
if self.task_switch['grounding']:
|
325 |
+
grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
|
326 |
+
grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens)
|
327 |
+
extra['grounding_tokens'] = grounding_tokens
|
328 |
+
|
329 |
+
features = self.backbone(images.tensor)
|
330 |
+
outputs = self.sem_seg_head(features, extra=extra)
|
331 |
+
|
332 |
+
_outputs = {}
|
333 |
+
for key, value in outputs.items():
|
334 |
+
if key == 'pred_logits':
|
335 |
+
_outputs[key] = value[:,:self.num_queries-1]
|
336 |
+
elif key == 'pred_masks':
|
337 |
+
_outputs[key] = value[:,:self.num_queries-1]
|
338 |
+
if self.task_switch['grounding']:
|
339 |
+
_outputs['pred_gmasks'] = value[:,self.num_queries:2*self.num_queries-1]
|
340 |
+
elif key == 'pred_captions':
|
341 |
+
_outputs[key] = value[:,:self.num_queries-1]
|
342 |
+
if self.task_switch['grounding']:
|
343 |
+
_outputs['pred_gtexts'] = value[:,self.num_queries:2*self.num_queries-1]
|
344 |
+
elif key == 'aux_outputs':
|
345 |
+
_outputs[key] = []
|
346 |
+
for i in range(len(value)):
|
347 |
+
_outputs[key] += [{}]
|
348 |
+
for _key, _value in value[i].items():
|
349 |
+
if _key == 'pred_logits':
|
350 |
+
_outputs[key][i][_key] = _value[:,:self.num_queries-1]
|
351 |
+
elif _key == 'pred_masks':
|
352 |
+
_outputs[key][i][_key] = _value[:,:self.num_queries-1]
|
353 |
+
if self.task_switch['grounding']:
|
354 |
+
_outputs[key][i]['pred_gmasks'] = _value[:,self.num_queries:2*self.num_queries-1]
|
355 |
+
elif _key == 'pred_captions':
|
356 |
+
_outputs[key][i][_key] = _value[:,:self.num_queries-1]
|
357 |
+
if self.task_switch['grounding']:
|
358 |
+
_outputs[key][i]['pred_gtexts'] = _value[:,self.num_queries:2*self.num_queries-1]
|
359 |
+
outputs = _outputs
|
360 |
+
|
361 |
+
extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
|
362 |
+
'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default'))}
|
363 |
+
|
364 |
+
# bipartite matching-based loss
|
365 |
+
self.criterion.losses = self.losses['seg'] # seg criterion losses
|
366 |
+
losses = self.criterion(outputs, targets, extra)
|
367 |
+
|
368 |
+
del outputs
|
369 |
+
del _outputs
|
370 |
+
return losses
|
371 |
+
|
372 |
+
def forward_vlp(self, batched_inputs):
|
373 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
374 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
375 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
376 |
+
targets_vlp = self.prepare_vlp_targets(batched_inputs, images.tensor.device)
|
377 |
+
|
378 |
+
extra = {"token_embedding": self.sem_seg_head.predictor.lang_encoder.lang_encoder.token_embedding,
|
379 |
+
"lang_encoder": self.sem_seg_head.predictor.lang_encoder,
|
380 |
+
"training": self.training}
|
381 |
+
|
382 |
+
features = self.backbone(images.tensor)
|
383 |
+
outputs = self.sem_seg_head(features, target_queries=None, target_vlp=targets_vlp, task='vlp', extra=extra)
|
384 |
+
|
385 |
+
for key, value in outputs.items():
|
386 |
+
if key == 'pred_captionings':
|
387 |
+
outputs[key] = value
|
388 |
+
elif key == 'pred_captions':
|
389 |
+
# outputs[key] = value[:,-1:]
|
390 |
+
outputs[key] = value
|
391 |
+
elif key == 'aux_outputs':
|
392 |
+
outputs[key] = []
|
393 |
+
for i in range(len(value)):
|
394 |
+
outputs[key] += [{}]
|
395 |
+
for _key, _value in value[i].items():
|
396 |
+
if _key == 'pred_captions':
|
397 |
+
# outputs[key][i][_key] = _value[:,-1:]
|
398 |
+
outputs[key][i][_key] = _value
|
399 |
+
elif _key == 'pred_captionings':
|
400 |
+
outputs[key][i][_key] = _value
|
401 |
+
|
402 |
+
self.criterion.losses = self.losses['vlp'] # vlp criterion losses
|
403 |
+
losses = self.criterion.forward_vlp(outputs, targets_vlp, extra)
|
404 |
+
del outputs
|
405 |
+
|
406 |
+
if self.task_switch['retrieval'] and self.retrieval_emsemble:
|
407 |
+
# compute an image-text contrastive (retrieval) loss on pooled backbone features
|
408 |
+
v_emb = features['res5']
|
409 |
+
bs,nc,_,_ = v_emb.shape
|
410 |
+
v_emb = v_emb.reshape(bs,nc,-1)
|
411 |
+
v_emb = F.adaptive_avg_pool1d(v_emb, 1).reshape(bs,nc) @ self.backbone_proj
|
412 |
+
t_emb = torch.cat([x['caption_proj'] for x in targets_vlp], dim=0)
|
413 |
+
loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, self.sem_seg_head.predictor.lang_encoder, None)
|
414 |
+
losses['loss_retrieval_backbone_0'] = loss_contrast
|
415 |
+
return losses
|
416 |
+
|
417 |
+
def evaluate(self, batched_inputs):
|
418 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
419 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
420 |
+
|
421 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
422 |
+
img_bs = images.tensor.shape[0]
|
423 |
+
|
424 |
+
targets = targets_grounding = queries_grounding = None
|
425 |
+
features = self.backbone(images.tensor)
|
426 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
427 |
+
|
428 |
+
mask_cls_results = outputs["pred_logits"]
|
429 |
+
mask_pred_results = outputs["pred_masks"]
|
430 |
+
box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
|
431 |
+
caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
|
432 |
+
|
433 |
+
# upsample masks
|
434 |
+
mask_pred_results = F.interpolate(
|
435 |
+
mask_pred_results,
|
436 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
437 |
+
mode="bicubic",
|
438 |
+
align_corners=False,
|
439 |
+
antialias=True
|
440 |
+
)
|
441 |
+
|
442 |
+
input_size = mask_pred_results.shape[-2:]
|
443 |
+
keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
|
444 |
+
del outputs
|
445 |
+
|
446 |
+
processed_results = []
|
447 |
+
for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
|
448 |
+
mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
|
449 |
+
):
|
450 |
+
height = input_per_image.get("height", image_size[0])
|
451 |
+
width = input_per_image.get("width", image_size[1])
|
452 |
+
processed_results.append({})
|
453 |
+
|
454 |
+
if self.sem_seg_postprocess_before_inference:
|
455 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
456 |
+
mask_pred_result, image_size, height, width
|
457 |
+
)
|
458 |
+
mask_cls_result = mask_cls_result.to(mask_pred_result)
|
459 |
+
|
460 |
+
# semantic segmentation inference
|
461 |
+
if self.semantic_on:
|
462 |
+
r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
|
463 |
+
if not self.sem_seg_postprocess_before_inference:
|
464 |
+
r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
|
465 |
+
processed_results[-1]["sem_seg"] = r
|
466 |
+
|
467 |
+
# panoptic segmentation inference
|
468 |
+
if self.panoptic_on:
|
469 |
+
panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
|
470 |
+
processed_results[-1]["panoptic_seg"] = panoptic_r
|
471 |
+
|
472 |
+
# instance segmentation inference
|
473 |
+
if self.instance_on:
|
474 |
+
if self.task_switch['bbox']:
|
475 |
+
box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
|
476 |
+
instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
|
477 |
+
processed_results[-1]["instances"] = instance_r
|
478 |
+
if self.task_switch['caption']:
|
479 |
+
processed_results[-1]["captions"] = caption_pred_result
|
480 |
+
processed_results[-1]["masks"] = mask_pred_result
|
481 |
+
|
482 |
+
return processed_results
|
483 |
+
|
484 |
+
def evaluate_retrieval(self, batched_inputs):
|
485 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
486 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
487 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
488 |
+
img_bs = images.tensor.shape[0]
|
489 |
+
|
490 |
+
targets = targets_grounding = queries_grounding = None
|
491 |
+
features = self.backbone(images.tensor)
|
492 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
493 |
+
v_emb_it = outputs['pred_captions'][:,-1]
|
494 |
+
|
495 |
+
# compute backbone score
|
496 |
+
if self.task_switch['retrieval'] and self.retrieval_emsemble:
|
497 |
+
_v_emb_it = features['res5']
|
498 |
+
bs,nc,_,_ = _v_emb_it.shape
|
499 |
+
_v_emb_it = _v_emb_it.reshape(bs,nc,-1)
|
500 |
+
_v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj
|
501 |
+
|
502 |
+
processed_results = []
|
503 |
+
for idx, batch_data in enumerate(batched_inputs):
|
504 |
+
caption_ids = []
|
505 |
+
t_emb_its = []
|
506 |
+
processed_results.append({})
|
507 |
+
for caption in batch_data['captions']:
|
508 |
+
lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
|
509 |
+
t_emb_it = lang_results['class_emb']
|
510 |
+
caption_ids.append(batch_data['image_id'])
|
511 |
+
t_emb_its.append(t_emb_it)
|
512 |
+
|
513 |
+
t_emb_it = torch.cat(t_emb_its, dim=0)
|
514 |
+
|
515 |
+
image_embeds = [v_emb_it[idx].unsqueeze(0)]
|
516 |
+
if self.task_switch['retrieval'] and self.retrieval_emsemble:
|
517 |
+
image_embeds += [_v_emb_it[idx].unsqueeze(0)]
|
518 |
+
caption_results = {
|
519 |
+
'image_embeds': image_embeds,
|
520 |
+
'text_embeds': t_emb_it,
|
521 |
+
'caption_ids': caption_ids,
|
522 |
+
'image_ids': batch_data['image_id'],
|
523 |
+
}
|
524 |
+
processed_results[-1]["caption"] = caption_results
|
525 |
+
|
526 |
+
del features
|
527 |
+
return processed_results
|
528 |
+
|
529 |
+
def evaluate_captioning(self, batched_inputs):
|
530 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
531 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
532 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
533 |
+
img_bs = images.tensor.shape[0]
|
534 |
+
|
535 |
+
if not hasattr(self, 'start_token'):
|
536 |
+
self.start_token = torch.tensor([[49406]*77], device=self.device)
|
537 |
+
|
538 |
+
targets = targets_grounding = queries_grounding = None
|
539 |
+
features = self.backbone(images.tensor)
|
540 |
+
|
541 |
+
captioning_mask = None
|
542 |
+
if 'captioning_mask' in batched_inputs[-1]:
|
543 |
+
captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])
|
544 |
+
|
545 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra={'start_token': self.start_token, 'captioning_mask': captioning_mask})
|
546 |
+
|
547 |
+
processed_results = []
|
548 |
+
for idx, batch_data in enumerate(batched_inputs):
|
549 |
+
processed_results.append({})
|
550 |
+
processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
|
551 |
+
processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
|
552 |
+
processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']
|
553 |
+
|
554 |
+
return processed_results
|
555 |
+
|
556 |
+
def evaluate_classification(self, batched_inputs):
|
557 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
558 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
559 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
560 |
+
img_bs = images.tensor.shape[0]
|
561 |
+
|
562 |
+
targets = targets_grounding = queries_grounding = None
|
563 |
+
features = self.backbone(images.tensor)
|
564 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
565 |
+
|
566 |
+
processed_results = []
|
567 |
+
for idx, batch_data in enumerate(batched_inputs):
|
568 |
+
processed_results.append({})
|
569 |
+
processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
|
570 |
+
return processed_results
|
571 |
+
|
572 |
+
def evaluate_grounding_baseline(self, batched_inputs, mode):
|
573 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
574 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
575 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
576 |
+
img_bs = images.tensor.shape[0]
|
577 |
+
|
578 |
+
targets = targets_grounding = queries_grounding = None
|
579 |
+
features = self.backbone(images.tensor)
|
580 |
+
outputs = self.sem_seg_head(features, target_queries=queries_grounding)
|
581 |
+
|
582 |
+
mask_pred_results = outputs["pred_masks"]
|
583 |
+
caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
|
584 |
+
|
585 |
+
# upsample masks
|
586 |
+
mask_pred_results = F.interpolate(
|
587 |
+
mask_pred_results,
|
588 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
589 |
+
mode="bicubic",
|
590 |
+
align_corners=False,
|
591 |
+
antialias=True
|
592 |
+
)
|
593 |
+
|
594 |
+
processed_results = []
|
595 |
+
for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
|
596 |
+
mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
|
597 |
+
):
|
598 |
+
height = input_per_image.get("height", image_size[0])
|
599 |
+
width = input_per_image.get("width", image_size[1])
|
600 |
+
processed_results.append({})
|
601 |
+
|
602 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
603 |
+
mask_pred_result, image_size, height, width
|
604 |
+
)[:-1]
|
605 |
+
|
606 |
+
texts_all = input_per_image['groundings']['texts']
|
607 |
+
grd_masks = []
|
608 |
+
for texts in texts_all:
|
609 |
+
if mode == 'grounding_refcoco':
|
610 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
|
611 |
+
elif mode == 'grounding_phrasecut':
|
612 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
|
613 |
+
t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
|
614 |
+
v_emb = caption_pred_result[:-1]
|
615 |
+
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
616 |
+
vt_sim = v_emb @ t_emb
|
617 |
+
max_id = vt_sim.max(0)[1][0]
|
618 |
+
grd_masks += [mask_pred_result[max_id]]
|
619 |
+
processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)
|
620 |
+
|
621 |
+
return processed_results
|
622 |
+
|
623 |
+
def evaluate_grounding(self, batched_inputs, mode):
|
624 |
+
images = [x["image"].to(self.device) for x in batched_inputs]
|
625 |
+
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
626 |
+
images = ImageList.from_tensors(images, self.size_divisibility)
|
627 |
+
|
628 |
+
extra = {}
|
629 |
+
# mask_pred_results = []
|
630 |
+
# for idx, batch_per_image in enumerate(batched_inputs):
|
631 |
+
# grd_texts = batch_per_image['groundings']['texts']
|
632 |
+
# grd_masks = []
|
633 |
+
# for anno_text in grd_texts:
|
634 |
+
# gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
|
635 |
+
# token_emb = gtext['token_emb']
|
636 |
+
# tokens = gtext['tokens']
|
637 |
+
|
638 |
+
# grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
|
639 |
+
# extra['grounding_tokens'] = grd_emb[:,None]
|
640 |
+
|
641 |
+
# assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
|
642 |
+
# features = self.backbone(images.tensor)
|
643 |
+
# outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
644 |
+
|
645 |
+
# pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
|
646 |
+
# v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
|
647 |
+
# t_emb = grd_emb[-1:]
|
648 |
+
|
649 |
+
# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
650 |
+
# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
651 |
+
|
652 |
+
# temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
653 |
+
# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
654 |
+
|
655 |
+
# matched_id = out_prob.max(0)[1]
|
656 |
+
# grd_masks += [pred_gmasks[matched_id,:,:]]
|
657 |
+
# mask_pred_results += [torch.cat(grd_masks)]
|
658 |
+
|
659 |
+
# comment for multi object inference.
|
660 |
+
mask_pred_results = []
|
661 |
+
for idx, batch_per_image in enumerate(batched_inputs):
|
662 |
+
grd_texts = batch_per_image['groundings']['texts']
|
663 |
+
grd_texts = [x[0] for x in grd_texts]
|
664 |
+
|
665 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
666 |
+
token_emb = gtext['token_emb']
|
667 |
+
tokens = gtext['tokens']
|
668 |
+
query_emb = token_emb[tokens['attention_mask'].bool()]
|
669 |
+
extra['grounding_tokens'] = query_emb[:,None]
|
670 |
+
|
671 |
+
features = self.backbone(images.tensor)
|
672 |
+
outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
|
673 |
+
|
674 |
+
pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
|
675 |
+
v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
|
676 |
+
t_emb = gtext['class_emb']
|
677 |
+
|
678 |
+
t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
679 |
+
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
|
680 |
+
|
681 |
+
temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
|
682 |
+
out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
|
683 |
+
|
684 |
+
matched_id = out_prob.max(0)[1]
|
685 |
+
mask_pred_results += [pred_gmasks[matched_id,:,:]]
|
686 |
+
|
687 |
+
for i in range(len(mask_pred_results)):
|
688 |
+
# upsample masks
|
689 |
+
mask_pred_results[i] = F.interpolate(
|
690 |
+
mask_pred_results[i][None,],
|
691 |
+
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
692 |
+
mode="bicubic",
|
693 |
+
align_corners=False,
|
694 |
+
antialias=True
|
695 |
+
)[0]
|
696 |
+
|
697 |
+
processed_results = []
|
698 |
+
for mask_pred_result, input_per_image, image_size in zip(
|
699 |
+
mask_pred_results, batched_inputs, images.image_sizes
|
700 |
+
):
|
701 |
+
height = input_per_image.get("height", image_size[0])
|
702 |
+
width = input_per_image.get("width", image_size[1])
|
703 |
+
processed_results.append({})
|
704 |
+
|
705 |
+
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
|
706 |
+
mask_pred_result, image_size, height, width
|
707 |
+
)
|
708 |
+
processed_results[-1]['grounding_mask'] = mask_pred_result
|
709 |
+
|
710 |
+
# compute bbox
|
711 |
+
# bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
|
712 |
+
# bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
|
713 |
+
# processed_results[-1]['grounding_box'] = bbox
|
714 |
+
|
715 |
+
return processed_results
|
716 |
+
|
717 |
+
def prepare_vlp_targets(self, batched_inputs, device):
|
718 |
+
input_ids = []
|
719 |
+
attention_mask = []
|
720 |
+
for cnt, x in enumerate(batched_inputs):
|
721 |
+
captions = x['captions']
|
722 |
+
randid = random.randint(0, len(captions)-1)
|
723 |
+
input_ids += x['tokens']['input_ids'][randid:randid+1]
|
724 |
+
attention_mask += x['tokens']['attention_mask'][randid:randid+1]
|
725 |
+
|
726 |
+
input_ids = torch.stack(input_ids)
|
727 |
+
attention_mask = torch.stack(attention_mask)
|
728 |
+
tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
|
729 |
+
lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)
|
730 |
+
|
731 |
+
target_vlp = []
|
732 |
+
for cnt, x in enumerate(batched_inputs):
|
733 |
+
target_dict = {}
|
734 |
+
target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
|
735 |
+
target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
|
736 |
+
target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
|
737 |
+
target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]
|
738 |
+
target_vlp.append(target_dict)
|
739 |
+
return target_vlp
|
740 |
+
|
741 |
+
def prepare_targets(self, batched_inputs, images):
|
742 |
+
h_pad, w_pad = images.tensor.shape[-2:]
|
743 |
+
new_targets = []
|
744 |
+
for idx, batch_per_image in enumerate(batched_inputs):
|
745 |
+
targets_per_image = batch_per_image["instances"].to(self.device)
|
746 |
+
|
747 |
+
# pad gt
|
748 |
+
gt_masks = targets_per_image.gt_masks
|
749 |
+
padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
|
750 |
+
padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
|
751 |
+
|
752 |
+
gt_boxes = targets_per_image.gt_boxes.tensor
|
753 |
+
ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
|
754 |
+
gt_boxes = gt_boxes / ratio
|
755 |
+
xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
|
756 |
+
gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
|
757 |
+
|
758 |
+
target_dict = {
|
759 |
+
"labels": targets_per_image.gt_classes,
|
760 |
+
"is_things": targets_per_image.is_things,
|
761 |
+
"masks": padded_masks,
|
762 |
+
"boxes": gt_boxes
|
763 |
+
}
|
764 |
+
|
765 |
+
if self.task_switch['caption']:
|
766 |
+
caption = batch_per_image["captions"]
|
767 |
+
caption_noun = batch_per_image["captions_noun"]
|
768 |
+
rand_index = random.randint(0, len(caption)-1)
|
769 |
+
|
770 |
+
text = caption[rand_index]
|
771 |
+
nouns = caption_noun[rand_index]
|
772 |
+
noun_captions = [prompt_engineering(noun, topk=10000, suffix='.') for noun in nouns] + [text]
|
773 |
+
|
774 |
+
self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(noun_captions, is_eval=False, name='caption_noun', prompt=False)
|
775 |
+
ctext = getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption_noun'))
|
776 |
+
target_dict["captions"] = ctext
|
777 |
+
|
778 |
+
target_dict["captions_hash"] = [(hash(st.stem(txt)) % 10**16) for txt in (nouns + [text])]
|
779 |
+
target_dict["labels_hash"] = [(hash(st.stem(COCO_PANOPTIC_CLASSES[label_id].replace('-other','').replace('-merged','').replace('-stuff',''))) % 10**16) for label_id in target_dict['labels']]
|
780 |
+
|
781 |
+
if self.task_switch['grounding']:
|
782 |
+
grd_masks = batch_per_image['groundings']['masks']
|
783 |
+
grd_texts = batch_per_image['groundings']['texts']
|
784 |
+
grd_hash = batch_per_image['groundings']['hash']
|
785 |
+
grd_task = batch_per_image['groundings']['mode']
|
786 |
+
|
787 |
+
if len(grd_masks) == 0:
|
788 |
+
padded_masks = None
|
789 |
+
else:
|
790 |
+
padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
|
791 |
+
padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
|
792 |
+
|
793 |
+
gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
|
794 |
+
token_emb = gtext['token_emb']
|
795 |
+
tokens = gtext['tokens']
|
796 |
+
|
797 |
+
unique_hash_id = np.unique(grd_hash, return_index=True)[1]
|
798 |
+
selected_mask = np.zeros(len(grd_hash)).astype(bool) # np.bool was removed in NumPy 1.24+; use the builtin bool
|
799 |
+
selected_mask[unique_hash_id] = True
|
800 |
+
|
801 |
+
selected_token_emb = token_emb[selected_mask]
|
802 |
+
selected_attn_mask = tokens['attention_mask'][selected_mask]
|
803 |
+
query_emb = selected_token_emb[selected_attn_mask.bool()]
|
804 |
+
|
805 |
+
class_idx = tokens['attention_mask'].sum(dim=-1) - 1
|
806 |
+
class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
|
807 |
+
class_emb = token_emb[class_idx]
|
808 |
+
|
809 |
+
target_dict['grounding_masks'] = padded_masks
|
810 |
+
target_dict['grounding_query_embs'] = query_emb
|
811 |
+
target_dict['grounding_class_embs'] = class_emb
|
812 |
+
target_dict['grounding_hash'] = grd_hash
|
813 |
+
target_dict['grounding_task'] = grd_task
|
814 |
+
|
815 |
+
new_targets.append(target_dict)
|
816 |
+
return new_targets
|
817 |
+
|
818 |
+
def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
|
819 |
+
if keep_sem_bgd:
|
820 |
+
mask_cls = F.softmax(mask_cls, dim=-1)
|
821 |
+
else:
|
822 |
+
mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
|
823 |
+
mask_pred = mask_pred.sigmoid()
|
824 |
+
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
|
825 |
+
return semseg
|
826 |
+
|
827 |
+
def panoptic_inference(self, mask_cls, mask_pred):
|
828 |
+
scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
|
829 |
+
mask_pred = mask_pred.sigmoid()
|
830 |
+
|
831 |
+
keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
|
832 |
+
cur_scores = scores[keep]
|
833 |
+
cur_classes = labels[keep]
|
834 |
+
cur_masks = mask_pred[keep]
|
835 |
+
cur_mask_cls = mask_cls[keep]
|
836 |
+
cur_mask_cls = cur_mask_cls[:, :-1]
|
837 |
+
cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
|
838 |
+
|
839 |
+
h, w = cur_masks.shape[-2:]
|
840 |
+
panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
|
841 |
+
segments_info = []
|
842 |
+
|
843 |
+
current_segment_id = 0
|
844 |
+
|
845 |
+
if cur_masks.shape[0] == 0:
|
846 |
+
# We didn't detect any mask :(
|
847 |
+
return panoptic_seg, segments_info
|
848 |
+
else:
|
849 |
+
# take argmax
|
850 |
+
cur_mask_ids = cur_prob_masks.argmax(0)
|
851 |
+
stuff_memory_list = {}
|
852 |
+
thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
|
853 |
+
for k in range(cur_classes.shape[0]):
|
854 |
+
pred_class = cur_classes[k].item()
|
855 |
+
isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
|
856 |
+
mask_area = (cur_mask_ids == k).sum().item()
|
857 |
+
original_area = (cur_masks[k] >= 0.5).sum().item()
|
858 |
+
mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
|
859 |
+
|
860 |
+
if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
|
861 |
+
if mask_area / original_area < self.overlap_threshold:
|
862 |
+
continue
|
863 |
+
|
864 |
+
# merge stuff regions
|
865 |
+
if not isthing:
|
866 |
+
if int(pred_class) in stuff_memory_list.keys():
|
867 |
+
panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
|
868 |
+
continue
|
869 |
+
else:
|
870 |
+
stuff_memory_list[int(pred_class)] = current_segment_id + 1
|
871 |
+
|
872 |
+
current_segment_id += 1
|
873 |
+
panoptic_seg[mask] = current_segment_id
|
874 |
+
|
875 |
+
segments_info.append(
|
876 |
+
{
|
877 |
+
"id": current_segment_id,
|
878 |
+
"isthing": bool(isthing),
|
879 |
+
"category_id": int(pred_class),
|
880 |
+
}
|
881 |
+
)
|
882 |
+
return panoptic_seg, segments_info
|
883 |
+
|
884 |
+
def instance_inference(self, mask_cls, mask_pred, box_pred):
|
885 |
+
# mask_pred is already processed to have the same shape as original input
|
886 |
+
image_size = mask_pred.shape[-2:]
|
887 |
+
|
888 |
+
# [Q, K]
|
889 |
+
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
|
890 |
+
labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
|
891 |
+
# scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
|
892 |
+
scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
|
893 |
+
|
894 |
+
labels_per_image = labels[topk_indices]
|
895 |
+
topk_indices = (topk_indices // self.sem_seg_head.num_classes)
|
896 |
+
# mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
|
897 |
+
mask_pred = mask_pred[topk_indices]
|
898 |
+
if box_pred is not None:
|
899 |
+
box_pred = box_pred[topk_indices]
|
900 |
+
|
901 |
+
# if this is panoptic segmentation, we only keep the "thing" classes
|
902 |
+
if self.panoptic_on:
|
903 |
+
thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
|
904 |
+
keep = torch.zeros_like(scores_per_image).bool()
|
905 |
+
for i, lab in enumerate(labels_per_image):
|
906 |
+
keep[i] = lab in thing_dataset_id_to_contiguous_id.values()
|
907 |
+
|
908 |
+
scores_per_image = scores_per_image[keep]
|
909 |
+
labels_per_image = labels_per_image[keep]
|
910 |
+
mask_pred = mask_pred[keep]
|
911 |
+
|
912 |
+
if box_pred is not None:
|
913 |
+
box_pred = box_pred[keep]
|
914 |
+
|
915 |
+
result = Instances(image_size)
|
916 |
+
# mask (before sigmoid)
|
917 |
+
result.pred_masks = (mask_pred > 0).float()
|
918 |
+
# result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
|
919 |
+
# Uncomment the following to get boxes from masks (this is slow)
|
920 |
+
|
921 |
+
if box_pred is not None:
|
922 |
+
result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
|
923 |
+
else:
|
924 |
+
result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
|
925 |
+
|
926 |
+
# calculate average mask prob
|
927 |
+
mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
|
928 |
+
result.scores = scores_per_image * mask_scores_per_image
|
929 |
+
result.pred_classes = labels_per_image
|
930 |
+
|
931 |
+
return result
|
932 |
+
|
933 |
+
|
934 |
+
|
935 |
+
@register_model
|
936 |
+
def get_xdecoder_model(cfg, **kwargs):
|
937 |
+
return GeneralizedXdecoder(cfg)
|
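Note: the semantic_inference step added above collapses per-query predictions into per-class score maps with a single einsum. Below is a minimal, self-contained sketch of that combination with toy shapes (illustrative only, not the model's real tensors):

import torch

# Per-query class probabilities (Q, C) weighted by per-query mask
# probabilities (Q, H, W) give per-class score maps (C, H, W).
Q, C, H, W = 5, 3, 4, 4
mask_cls = torch.softmax(torch.randn(Q, C + 1), dim=-1)[..., :-1]  # drop the no-object column
mask_pred = torch.randn(Q, H, W).sigmoid()
semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
assert semseg.shape == (C, H, W)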
modeling/body/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .xdecoder_head import *
from .build import *

def build_xdecoder_head(config, *args, **kwargs):
    model_name = config['MODEL']['HEAD']
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    body = model_entrypoints(model_name)(config, *args, **kwargs)
    return body
modeling/body/build.py
ADDED
@@ -0,0 +1,13 @@
_model_entrypoints = {}

def register_body(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
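Note: register_body keys each entrypoint by the basename of the module that defines it, so config['MODEL']['HEAD'] must match a module name such as 'xdecoder_head'. A minimal sketch of the same registry idea, keyed by function name purely for illustration (in a standalone script fn.__module__ would be '__main__'):

_entrypoints = {}

def register(fn):
    # the real register_body uses fn.__module__.split('.')[-1] as the key
    _entrypoints[fn.__name__] = fn
    return fn

@register
def toy_head(config):
    return {"built_with": config}

def build(name, config):
    if name not in _entrypoints:
        raise ValueError(f'Unknown model: {name}')
    return _entrypoints[name](config)

print(build('toy_head', {'dim': 512}))  # {'built_with': {'dim': 512}}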
modeling/body/xdecoder_head.py
ADDED
@@ -0,0 +1,126 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Dict

from torch import nn

from detectron2.layers import ShapeSpec

from .build import register_body
from ..vision.encoder import build_encoder
from ..interface import build_decoder
from ..utils import configurable


class XdecoderHead(nn.Module):

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        binary_classes: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

        if binary_classes:
            self.num_classes = 1

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):

        in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        # figure out in_channels to transformer predictor
        if in_features_type == "transformer_encoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        elif in_features_type == "pixel_embedding":
            transformer_predictor_in_channels = enc_cfg['MASK_DIM']
        elif in_features_type == "multi_scale_pixel_decoder":
            transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
        else:
            transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels

        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
            },
            "ignore_value": enc_cfg['IGNORE_VALUE'],
            "num_classes": enc_cfg.get('NUM_CLASSES', None),
            "pixel_decoder": build_encoder(cfg, input_shape),
            "loss_weight": enc_cfg['LOSS_WEIGHT'],
            "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
            "transformer_predictor": build_decoder(
                cfg,
                transformer_predictor_in_channels,
                lang_encoder,
                mask_classification=True,
                extra=extra,
            ),
            "binary_classes": enc_cfg['BINARY_CLASSES']
        }

    def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        return self.layers(features, mask, target_queries, target_vlp, task, extra)

    def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)

        if self.transformer_in_feature == "multi_scale_pixel_decoder":
            predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
        else:
            if self.transformer_in_feature == "transformer_encoder":
                assert (
                    transformer_encoder_features is not None
                ), "Please use the TransformerEncoderPixelDecoder."
                predictions = self.predictor(transformer_encoder_features, mask_features, mask)
            elif self.transformer_in_feature == "pixel_embedding":
                predictions = self.predictor(mask_features, mask_features, mask)
            else:
                predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
        return predictions


@register_body
def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
    return XdecoderHead(cfg, input_shape, lang_encoder, extra)
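Note: from_config above follows the usual detectron2 @configurable pattern: a classmethod translates the nested config dict into explicit constructor kwargs so that __init__ stays config-agnostic. A rough, illustrative-only sketch of that mapping (ToyHead and its keys are hypothetical, not part of this repo, and the detectron2 decorator machinery is omitted):

class ToyHead:
    def __init__(self, *, num_classes, loss_weight=1.0):
        self.num_classes = num_classes
        self.loss_weight = loss_weight

    @classmethod
    def from_config(cls, cfg):
        enc_cfg = cfg['MODEL']['ENCODER']
        return cls(num_classes=enc_cfg.get('NUM_CLASSES', 1),
                   loss_weight=enc_cfg['LOSS_WEIGHT'])

head = ToyHead.from_config({'MODEL': {'ENCODER': {'NUM_CLASSES': 16, 'LOSS_WEIGHT': 2.0}}})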
modeling/interface/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .xdecoder import *
from .seem_v0 import *
from .seem_v1 import *
from .seem_demo import *
from .build import *

def build_decoder(config, *args, **kwargs):
    model_name = config['MODEL']['DECODER']['NAME']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, *args, **kwargs)
modeling/interface/build.py
ADDED
@@ -0,0 +1,14 @@
_model_entrypoints = {}


def register_decoder(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
modeling/interface/modules.py
ADDED
@@ -0,0 +1,200 @@
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn, Tensor
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from timm.models.layers import trunc_normal_
|
8 |
+
from detectron2.layers import Conv2d
|
9 |
+
import fvcore.nn.weight_init as weight_init
|
10 |
+
|
11 |
+
from ..utils import MultiheadAttention
|
12 |
+
|
13 |
+
|
14 |
+
class SelfAttentionLayer(nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
17 |
+
activation="relu", normalize_before=False):
|
18 |
+
super().__init__()
|
19 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
|
20 |
+
|
21 |
+
self.norm = nn.LayerNorm(d_model)
|
22 |
+
self.dropout = nn.Dropout(dropout)
|
23 |
+
|
24 |
+
self.activation = _get_activation_fn(activation)
|
25 |
+
self.normalize_before = normalize_before
|
26 |
+
|
27 |
+
self._reset_parameters()
|
28 |
+
|
29 |
+
def _reset_parameters(self):
|
30 |
+
for p in self.parameters():
|
31 |
+
if p.dim() > 1:
|
32 |
+
nn.init.xavier_uniform_(p)
|
33 |
+
|
34 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
35 |
+
return tensor if pos is None else tensor + pos
|
36 |
+
|
37 |
+
def forward_post(self, tgt,
|
38 |
+
tgt_mask: Optional[Tensor] = None,
|
39 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
40 |
+
query_pos: Optional[Tensor] = None):
|
41 |
+
|
42 |
+
q = k = self.with_pos_embed(tgt, query_pos)
|
43 |
+
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
|
44 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
45 |
+
tgt = tgt + self.dropout(tgt2)
|
46 |
+
tgt = self.norm(tgt)
|
47 |
+
return tgt
|
48 |
+
|
49 |
+
def forward_pre(self, tgt,
|
50 |
+
tgt_mask: Optional[Tensor] = None,
|
51 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
52 |
+
query_pos: Optional[Tensor] = None):
|
53 |
+
tgt2 = self.norm(tgt)
|
54 |
+
q = k = self.with_pos_embed(tgt2, query_pos)
|
55 |
+
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
56 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
57 |
+
tgt = tgt + self.dropout(tgt2)
|
58 |
+
|
59 |
+
return tgt
|
60 |
+
|
61 |
+
def forward(self, tgt,
|
62 |
+
tgt_mask: Optional[Tensor] = None,
|
63 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
64 |
+
query_pos: Optional[Tensor] = None):
|
65 |
+
if self.normalize_before:
|
66 |
+
return self.forward_pre(tgt, tgt_mask,
|
67 |
+
tgt_key_padding_mask, query_pos)
|
68 |
+
return self.forward_post(tgt, tgt_mask,
|
69 |
+
tgt_key_padding_mask, query_pos)
|
70 |
+
|
71 |
+
|
72 |
+
class CrossAttentionLayer(nn.Module):
|
73 |
+
|
74 |
+
def __init__(self, d_model, nhead, dropout=0.0,
|
75 |
+
activation="relu", normalize_before=False):
|
76 |
+
super().__init__()
|
77 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
78 |
+
|
79 |
+
self.norm = nn.LayerNorm(d_model)
|
80 |
+
self.dropout = nn.Dropout(dropout)
|
81 |
+
|
82 |
+
self.activation = _get_activation_fn(activation)
|
83 |
+
self.normalize_before = normalize_before
|
84 |
+
|
85 |
+
self._reset_parameters()
|
86 |
+
|
87 |
+
def _reset_parameters(self):
|
88 |
+
for p in self.parameters():
|
89 |
+
if p.dim() > 1:
|
90 |
+
nn.init.xavier_uniform_(p)
|
91 |
+
|
92 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
93 |
+
return tensor if pos is None else tensor + pos
|
94 |
+
|
95 |
+
def forward_post(self, tgt, memory,
|
96 |
+
memory_mask: Optional[Tensor] = None,
|
97 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
98 |
+
pos: Optional[Tensor] = None,
|
99 |
+
query_pos: Optional[Tensor] = None):
|
100 |
+
tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
101 |
+
key=self.with_pos_embed(memory, pos),
|
102 |
+
value=memory, attn_mask=memory_mask,
|
103 |
+
key_padding_mask=memory_key_padding_mask)
|
104 |
+
tgt = tgt + self.dropout(tgt2)
|
105 |
+
tgt = self.norm(tgt)
|
106 |
+
return tgt, avg_attn
|
107 |
+
|
108 |
+
def forward_pre(self, tgt, memory,
|
109 |
+
memory_mask: Optional[Tensor] = None,
|
110 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
111 |
+
pos: Optional[Tensor] = None,
|
112 |
+
query_pos: Optional[Tensor] = None):
|
113 |
+
tgt2 = self.norm(tgt)
|
114 |
+
tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
|
115 |
+
key=self.with_pos_embed(memory, pos),
|
116 |
+
value=memory, attn_mask=memory_mask,
|
117 |
+
key_padding_mask=memory_key_padding_mask)
|
118 |
+
tgt = tgt + self.dropout(tgt2)
|
119 |
+
|
120 |
+
return tgt, avg_attn
|
121 |
+
|
122 |
+
def forward(self, tgt, memory,
|
123 |
+
memory_mask: Optional[Tensor] = None,
|
124 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
125 |
+
pos: Optional[Tensor] = None,
|
126 |
+
query_pos: Optional[Tensor] = None):
|
127 |
+
if self.normalize_before:
|
128 |
+
return self.forward_pre(tgt, memory, memory_mask,
|
129 |
+
memory_key_padding_mask, pos, query_pos)
|
130 |
+
return self.forward_post(tgt, memory, memory_mask,
|
131 |
+
memory_key_padding_mask, pos, query_pos)
|
132 |
+
|
133 |
+
|
134 |
+
class FFNLayer(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
|
137 |
+
activation="relu", normalize_before=False):
|
138 |
+
super().__init__()
|
139 |
+
# Implementation of Feedforward model
|
140 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
141 |
+
self.dropout = nn.Dropout(dropout)
|
142 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
143 |
+
|
144 |
+
self.norm = nn.LayerNorm(d_model)
|
145 |
+
|
146 |
+
self.activation = _get_activation_fn(activation)
|
147 |
+
self.normalize_before = normalize_before
|
148 |
+
|
149 |
+
self._reset_parameters()
|
150 |
+
|
151 |
+
def _reset_parameters(self):
|
152 |
+
for p in self.parameters():
|
153 |
+
if p.dim() > 1:
|
154 |
+
nn.init.xavier_uniform_(p)
|
155 |
+
|
156 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
157 |
+
return tensor if pos is None else tensor + pos
|
158 |
+
|
159 |
+
def forward_post(self, tgt):
|
160 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
161 |
+
tgt = tgt + self.dropout(tgt2)
|
162 |
+
tgt = self.norm(tgt)
|
163 |
+
return tgt
|
164 |
+
|
165 |
+
def forward_pre(self, tgt):
|
166 |
+
tgt2 = self.norm(tgt)
|
167 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
168 |
+
tgt = tgt + self.dropout(tgt2)
|
169 |
+
return tgt
|
170 |
+
|
171 |
+
def forward(self, tgt):
|
172 |
+
if self.normalize_before:
|
173 |
+
return self.forward_pre(tgt)
|
174 |
+
return self.forward_post(tgt)
|
175 |
+
|
176 |
+
|
177 |
+
def _get_activation_fn(activation):
|
178 |
+
"""Return an activation function given a string"""
|
179 |
+
if activation == "relu":
|
180 |
+
return F.relu
|
181 |
+
if activation == "gelu":
|
182 |
+
return F.gelu
|
183 |
+
if activation == "glu":
|
184 |
+
return F.glu
|
185 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
186 |
+
|
187 |
+
|
188 |
+
class MLP(nn.Module):
|
189 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
190 |
+
|
191 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
192 |
+
super().__init__()
|
193 |
+
self.num_layers = num_layers
|
194 |
+
h = [hidden_dim] * (num_layers - 1)
|
195 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
for i, layer in enumerate(self.layers):
|
199 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
200 |
+
return x
|
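Note: the normalize_before flag in the SelfAttentionLayer/CrossAttentionLayer/FFNLayer above toggles between post-norm (residual add, then LayerNorm) and pre-norm (LayerNorm first, residual left unnormalized) orderings. A minimal sketch of the difference with toy shapes (not the repo's classes):

import torch
from torch import nn

norm = nn.LayerNorm(8)
ffn = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 8)

post_norm = norm(x + ffn(x))   # forward_post: add the residual, then normalize
pre_norm = x + ffn(norm(x))    # forward_pre: normalize the input, keep the residual raw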
modeling/interface/prototype/__init__.py
ADDED
File without changes
|
modeling/interface/prototype/attention_data_struct_seemdemo.py
ADDED
@@ -0,0 +1,265 @@
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
predict_name_matcher = {"predictions_class": ["pred_logits"],
|
13 |
+
"predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
|
14 |
+
"predictions_caption":["pred_captions", "pred_gtexts"],
|
15 |
+
"predictions_maskemb":["pred_maskembs", "pred_smaskembs"],
|
16 |
+
"predictions_pos_spatial":["pred_pspatials"],
|
17 |
+
"predictions_neg_spatial":["pred_nspatials"],
|
18 |
+
"predictions_pos_visual":["pred_pvisuals"],
|
19 |
+
"predictions_neg_visual":["pred_nvisuals"]}
|
20 |
+
|
21 |
+
predict_index_matcher = {"predictions_class": ["queries_object"],
|
22 |
+
"predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
|
23 |
+
"predictions_caption": ["queries_object", "queries_grounding"],
|
24 |
+
"predictions_maskemb":["queries_object", "queries_spatial"],
|
25 |
+
"predictions_pos_spatial":["all"],
|
26 |
+
"predictions_neg_spatial":["all"],
|
27 |
+
"predictions_pos_visual":["all"],
|
28 |
+
"predictions_neg_visual":["all"]}
|
29 |
+
|
30 |
+
class Variable(object):
|
31 |
+
'''
|
32 |
+
Store dataset variable for attention
|
33 |
+
output: embedding that accumulates during cross/self attention
|
34 |
+
pos: positional embedding that is fixed during cross/self attention
|
35 |
+
name: name of the variable
|
36 |
+
type: type of the variable, e.g. queries, tokens
|
37 |
+
attn_mask: attention mask for cross attention
|
38 |
+
masking: masking for padding
|
39 |
+
'''
|
40 |
+
def __init__(self, output, name, _type, pos=None):
|
41 |
+
self.output = output
|
42 |
+
self.pos = pos
|
43 |
+
self.name = name
|
44 |
+
self.type = _type
|
45 |
+
self.attn_mask = None
|
46 |
+
self.masking = None
|
47 |
+
|
48 |
+
def copy(self,):
|
49 |
+
output = self.output.clone() if self.output is not None else None
|
50 |
+
pos = self.pos.clone() if self.pos is not None else None
|
51 |
+
return Variable(output, self.name, self.type, pos)
|
52 |
+
|
53 |
+
class AttentionDataStruct(nn.Module):
|
54 |
+
'''
|
55 |
+
Store dataset structure for cross/self attention
|
56 |
+
task_switch: switch for different tasks
|
57 |
+
|
58 |
+
p_attn_variables: prototype of variables that is used in cross/self attention
|
59 |
+
p_self_attn: prototype of variables that is used in self attention
|
60 |
+
p_cross_attn: prototype of variables that is used in cross attention
|
61 |
+
p_iter: prototype of iteration for different queries
|
62 |
+
p_masking: prototype of masking for different tokens
|
63 |
+
p_duplication: prototype of duplication for different queries
|
64 |
+
'''
|
65 |
+
def __init__(self, attn_arch, task_switch):
|
66 |
+
super(AttentionDataStruct, self).__init__()
|
67 |
+
self.task_switch = task_switch
|
68 |
+
|
69 |
+
# p stands for prototype
|
70 |
+
self.p_attn_variables = attn_arch['VARIABLE']
|
71 |
+
self.p_self_attn = attn_arch['SELF_ATTENTION']
|
72 |
+
self.p_cross_attn = attn_arch['CROSS_ATTENTION']
|
73 |
+
self.p_masking = attn_arch['MASKING']
|
74 |
+
self.p_duplication = attn_arch['DUPLICATION']
|
75 |
+
|
76 |
+
self.num_layers = attn_arch['NUM_LAYERS']
|
77 |
+
|
78 |
+
def reset(self, flags, task, extra):
|
79 |
+
# reset variables
|
80 |
+
self.attn_variables = {}
|
81 |
+
self.cross_attn_dict = {}
|
82 |
+
self.self_attn_dict = {}
|
83 |
+
self.duplication_dict = {}
|
84 |
+
self.query_index = {}
|
85 |
+
self.output = {}
|
86 |
+
self.flags = {}
|
87 |
+
self.spatial_memory = {}
|
88 |
+
|
89 |
+
# initialize duplication
|
90 |
+
for key, values in self.p_duplication.items():
|
91 |
+
for name in values:
|
92 |
+
self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
|
93 |
+
|
94 |
+
# initialize flag
|
95 |
+
self.flags = {"object": True}
|
96 |
+
self.flags.update(flags)
|
97 |
+
|
98 |
+
# initialize task
|
99 |
+
self.task = task
|
100 |
+
|
101 |
+
# initialize output
|
102 |
+
if self.task_switch['mask']:
|
103 |
+
self.output['predictions_class'] = []
|
104 |
+
self.output['predictions_mask'] = []
|
105 |
+
self.output['predictions_maskemb'] = []
|
106 |
+
|
107 |
+
if self.task_switch['bbox']:
|
108 |
+
self.output['predictions_bbox'] = []
|
109 |
+
|
110 |
+
if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
|
111 |
+
self.output['predictions_pos_spatial'] = []
|
112 |
+
self.output['predictions_neg_spatial'] = []
|
113 |
+
|
114 |
+
if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
|
115 |
+
self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
|
116 |
+
|
117 |
+
if (self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True)) \
|
118 |
+
or (self.task_switch['audio'] and ('audio' in self.flags and self.flags['audio']==True)):
|
119 |
+
self.output['predictions_caption'] = []
|
120 |
+
|
121 |
+
if self.task_switch['visual']:
|
122 |
+
self.output['predictions_pos_visual'] = []
|
123 |
+
self.output['predictions_neg_visual'] = []
|
124 |
+
|
125 |
+
# initialize cross_attn, whether the variable is used in cross attention
|
126 |
+
for key, values in self.p_cross_attn.items():
|
127 |
+
for name in values:
|
128 |
+
self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
|
129 |
+
|
130 |
+
# initialize self_attn, whether the variable is used in self attention, and the interactions between queries
|
131 |
+
for key, values in self.p_self_attn.items():
|
132 |
+
for name in values:
|
133 |
+
self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
|
134 |
+
|
135 |
+
# initialize masking
|
136 |
+
self.masking = self.p_masking
|
137 |
+
|
138 |
+
# initialize query_index
|
139 |
+
self.query_index = {"all":[0, None]}
|
140 |
+
|
141 |
+
|
142 |
+
def set(self, name, _type, output=None, pos=None, var=None):
|
143 |
+
if var is not None:
|
144 |
+
self.attn_variables[name] = var
|
145 |
+
elif name in self.duplication_dict:
|
146 |
+
assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
|
147 |
+
self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
|
148 |
+
else:
|
149 |
+
var = Variable(output, name, _type, pos)
|
150 |
+
self.attn_variables[name] = var
|
151 |
+
|
152 |
+
def set_results(self, results):
|
153 |
+
for name in self.cross_attn_name:
|
154 |
+
self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
|
155 |
+
for key in self.output:
|
156 |
+
self.output[key].append(results[key])
|
157 |
+
|
158 |
+
def set_maskings(self, name, masking):
|
159 |
+
self.attn_variables[name].masking = masking
|
160 |
+
|
161 |
+
def cross_attn_variables(self, ):
|
162 |
+
cross_attn_name = [key for key, value in self.cross_attn_dict.items()
|
163 |
+
if (value==True) and (key in self.attn_variables)
|
164 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
165 |
+
self.cross_attn_name = cross_attn_name
|
166 |
+
|
167 |
+
output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
|
168 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
|
169 |
+
|
170 |
+
index = 0
|
171 |
+
for name in cross_attn_name:
|
172 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
173 |
+
index += self.attn_variables[name].output.shape[0]
|
174 |
+
return output, pos_emb
|
175 |
+
|
176 |
+
def cross_attn_mask(self, size, num_heads):
|
177 |
+
attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
|
178 |
+
|
179 |
+
# hard code memories_spatial to previous selected mask
|
180 |
+
if 'memories_spatial' in self.cross_attn_name:
|
181 |
+
memory_attn_mask = self.spatial_memory['prev_batch_mask']
|
182 |
+
bs,c,_,_ = memory_attn_mask.shape
|
183 |
+
memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
|
184 |
+
memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
|
185 |
+
attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
|
186 |
+
|
187 |
+
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
|
188 |
+
return attn_mask
|
189 |
+
|
190 |
+
def self_attn(self, bs, num_heads):
|
191 |
+
self_attn_name = [key for key, value in self.self_attn_dict.items()
|
192 |
+
if len(value)>0 and key in self.attn_variables
|
193 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
194 |
+
self.self_attn_name = self_attn_name
|
195 |
+
|
196 |
+
output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
|
197 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
|
198 |
+
|
199 |
+
index = 0
|
200 |
+
for name in self_attn_name:
|
201 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
202 |
+
index += self.attn_variables[name].output.shape[0]
|
203 |
+
|
204 |
+
self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
|
205 |
+
self_attn_pair = []
|
206 |
+
# build self_attention mask by query interaction
|
207 |
+
for key1, value in self.self_attn_dict.items():
|
208 |
+
for key2 in value:
|
209 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
210 |
+
# exclude the variables that are not used in the current layer
|
211 |
+
continue
|
212 |
+
if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
|
213 |
+
self_attn_pair += [[key1, key2]]
|
214 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
|
215 |
+
|
216 |
+
# build self_attention mask by masking, for bidirectional
|
217 |
+
for key in self.masking:
|
218 |
+
if key in self_attn_name:
|
219 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
|
220 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
|
221 |
+
|
222 |
+
# build self_attention mask by masking, for uni-directional
|
223 |
+
for key1, key2 in self_attn_pair:
|
224 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
225 |
+
# exclude the variables that are not used in the current layer
|
226 |
+
continue
|
227 |
+
if key1 in self.masking:
|
228 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
|
229 |
+
if key2 in self.masking:
|
230 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
|
231 |
+
|
232 |
+
self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
|
233 |
+
return output, pos_emb, self_attn_mask
|
234 |
+
|
235 |
+
def update_variables(self, output, mode):
|
236 |
+
name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
|
237 |
+
for key in name_set:
|
238 |
+
self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
|
239 |
+
|
240 |
+
def update_spatial_results(self, results):
|
241 |
+
v_emb = results['pred_smaskembs']
|
242 |
+
pred_smasks = results['pred_smasks']
|
243 |
+
|
244 |
+
s_emb = results['pred_pspatials']
|
245 |
+
pred_logits = v_emb @ s_emb.transpose(1,2)
|
246 |
+
logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
|
247 |
+
logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
|
248 |
+
logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
|
249 |
+
pred_masks_pos = pred_smasks[logits_idx][:,None,]
|
250 |
+
|
251 |
+
extra = {"prev_mask": pred_masks_pos}
|
252 |
+
return extra
|
253 |
+
|
254 |
+
def organize_output(self, ):
|
255 |
+
outputs = {}
|
256 |
+
outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
|
257 |
+
|
258 |
+
for key, values in self.output.items():
|
259 |
+
for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
|
260 |
+
if idx_name not in self.query_index:
|
261 |
+
continue
|
262 |
+
outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
263 |
+
for idx, aux_values in enumerate(self.output[key][:-1]):
|
264 |
+
outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
265 |
+
return outputs
|
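For context, the query bookkeeping above can be reduced to a few lines: each variable owns a [start, end) slice of the concatenated query tensor, pairwise interactions unblock rectangular regions of a boolean mask, and the mask is finally expanded per attention head. A minimal standalone sketch, with made-up variable names and sizes rather than values from the real config:

import torch

# illustrative sizes only, not taken from the model config
bs, num_heads = 2, 8
lengths = {"queries_object": 101, "tokens_grounding": 5}

# build [start, end) slices, mirroring self.query_index
query_index, index = {}, 0
for name, n in lengths.items():
    query_index[name] = [index, index + n]
    index += n
total = index

# True = blocked, False = allowed, same convention as the decoder
self_attn_mask = torch.ones(bs, total, total, dtype=torch.bool)

# let object queries attend to grounding tokens
s1, s2 = query_index["queries_object"], query_index["tokens_grounding"]
self_attn_mask[:, s1[0]:s1[1], s2[0]:s2[1]] = False

# one copy of the mask per head, as nn.MultiheadAttention expects
self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
print(self_attn_mask.shape)  # torch.Size([16, 106, 106])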
modeling/interface/prototype/attention_data_struct_seemv0.py
ADDED
@@ -0,0 +1,264 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
predict_name_matcher = {"predictions_class": ["pred_logits"],
|
6 |
+
"predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
|
7 |
+
"predictions_caption":["pred_captions", "pred_gtexts"],
|
8 |
+
"predictions_maskemb":["pred_smaskembs"],
|
9 |
+
"predictions_pos_spatial":["pred_pspatials"],
|
10 |
+
"predictions_neg_spatial":["pred_nspatials"],}
|
11 |
+
|
12 |
+
predict_index_matcher = {"predictions_class": ["queries_object"],
|
13 |
+
"predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
|
14 |
+
"predictions_caption": ["queries_object", "queries_grounding"],
|
15 |
+
"predictions_maskemb":["queries_spatial"],
|
16 |
+
"predictions_pos_spatial":["all"],
|
17 |
+
"predictions_neg_spatial":["all"],}
|
18 |
+
|
19 |
+
class Variable(object):
|
20 |
+
'''
|
21 |
+
Store dataset variable for attention
|
22 |
+
output: embedding that accumulates during cross/self attention
|
23 |
+
pos: positional embedding that is fixed during cross/self attention
|
24 |
+
name: name of the variable
|
25 |
+
type: type of the variable, e.g. queries, tokens
|
26 |
+
attn_mask: attention mask for cross attention
|
27 |
+
masking: masking for padding
|
28 |
+
'''
|
29 |
+
def __init__(self, output, name, _type, pos=None):
|
30 |
+
self.output = output
|
31 |
+
self.pos = pos
|
32 |
+
self.name = name
|
33 |
+
self.type = _type
|
34 |
+
self.attn_mask = None
|
35 |
+
self.masking = None
|
36 |
+
|
37 |
+
def copy(self,):
|
38 |
+
output = self.output.clone() if self.output is not None else None
|
39 |
+
pos = self.pos.clone() if self.pos is not None else None
|
40 |
+
return Variable(output, self.name, self.type, pos)
|
41 |
+
|
42 |
+
class AttentionDataStruct(nn.Module):
|
43 |
+
'''
|
44 |
+
Store dataset structure for cross/self attention
|
45 |
+
task_switch: switch for different tasks
|
46 |
+
|
47 |
+
p_attn_variables: prototype of variables that is used in cross/self attention
|
48 |
+
p_self_attn: prototype of variables that is used in self attention
|
49 |
+
p_cross_attn: prototype of variables that is used in cross attention
|
50 |
+
p_iter: prototype of iteration for different queries
|
51 |
+
p_masking: prototype of masking for different tokens
|
52 |
+
p_duplication: prototype of duplication for different queries
|
53 |
+
'''
|
54 |
+
def __init__(self, attn_arch, task_switch):
|
55 |
+
super(AttentionDataStruct, self).__init__()
|
56 |
+
self.task_switch = task_switch
|
57 |
+
|
58 |
+
# p stands for prototype
|
59 |
+
self.p_attn_variables = attn_arch['VARIABLE']
|
60 |
+
self.p_self_attn = attn_arch['SELF_ATTENTION']
|
61 |
+
self.p_cross_attn = attn_arch['CROSS_ATTENTION']
|
62 |
+
self.p_masking = attn_arch['MASKING']
|
63 |
+
self.p_duplication = attn_arch['DUPLICATION']
|
64 |
+
|
65 |
+
self.num_layers = attn_arch['NUM_LAYERS']
|
66 |
+
|
67 |
+
def reset(self, flags, task, extra):
|
68 |
+
# reset variables
|
69 |
+
self.attn_variables = {}
|
70 |
+
self.cross_attn_dict = {}
|
71 |
+
self.self_attn_dict = {}
|
72 |
+
self.duplication_dict = {}
|
73 |
+
self.query_index = {}
|
74 |
+
self.output = {}
|
75 |
+
self.flags = {}
|
76 |
+
self.spatial_memory = {}
|
77 |
+
|
78 |
+
# initialize duplication
|
79 |
+
for key, values in self.p_duplication.items():
|
80 |
+
for name in values:
|
81 |
+
self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
|
82 |
+
|
83 |
+
# initialize flag
|
84 |
+
self.flags = {"object": True}
|
85 |
+
self.flags.update(flags)
|
86 |
+
|
87 |
+
# initialize task
|
88 |
+
self.task = task
|
89 |
+
|
90 |
+
# initialize output
|
91 |
+
if self.task_switch['mask']:
|
92 |
+
self.output['predictions_class'] = []
|
93 |
+
self.output['predictions_mask'] = []
|
94 |
+
|
95 |
+
if self.task_switch['bbox']:
|
96 |
+
self.output['predictions_bbox'] = []
|
97 |
+
|
98 |
+
if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
|
99 |
+
self.output['predictions_maskemb'] = []
|
100 |
+
self.output['predictions_pos_spatial'] = []
|
101 |
+
self.output['predictions_neg_spatial'] = []
|
102 |
+
# self.spatial_memory['spatial_query_mode'] = extra['spatial_query_mode']
|
103 |
+
|
104 |
+
if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
|
105 |
+
self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
|
106 |
+
|
107 |
+
if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
|
108 |
+
self.output['predictions_caption'] = []
|
109 |
+
|
110 |
+
# initialize cross_attn, whether the variable is used in cross attention
|
111 |
+
for key, values in self.p_cross_attn.items():
|
112 |
+
for name in values:
|
113 |
+
self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
|
114 |
+
|
115 |
+
# initialize self_attn, whether the variable is used in self attention, and the interactions between queries
|
116 |
+
for key, values in self.p_self_attn.items():
|
117 |
+
for name in values:
|
118 |
+
self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
|
119 |
+
|
120 |
+
# initialize masking
|
121 |
+
self.masking = self.p_masking
|
122 |
+
|
123 |
+
# initialize query_index
|
124 |
+
self.query_index = {"all":[0, None]}
|
125 |
+
|
126 |
+
|
127 |
+
def set(self, name, _type, output=None, pos=None, var=None):
|
128 |
+
if var is not None:
|
129 |
+
self.attn_variables[name] = var
|
130 |
+
elif name in self.duplication_dict:
|
131 |
+
assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
|
132 |
+
self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
|
133 |
+
else:
|
134 |
+
var = Variable(output, name, _type, pos)
|
135 |
+
self.attn_variables[name] = var
|
136 |
+
|
137 |
+
def set_results(self, results):
|
138 |
+
for name in self.cross_attn_name:
|
139 |
+
self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
|
140 |
+
for key in self.output:
|
141 |
+
self.output[key].append(results[key])
|
142 |
+
|
143 |
+
def set_maskings(self, name, masking):
|
144 |
+
self.attn_variables[name].masking = masking
|
145 |
+
|
146 |
+
def cross_attn_variables(self, ):
|
147 |
+
cross_attn_name = [key for key, value in self.cross_attn_dict.items()
|
148 |
+
if (value==True) and (key in self.attn_variables)
|
149 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
150 |
+
self.cross_attn_name = cross_attn_name
|
151 |
+
|
152 |
+
output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
|
153 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
|
154 |
+
|
155 |
+
index = 0
|
156 |
+
for name in cross_attn_name:
|
157 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
158 |
+
index += self.attn_variables[name].output.shape[0]
|
159 |
+
return output, pos_emb
|
160 |
+
|
161 |
+
def cross_attn_mask(self, size, num_heads):
|
162 |
+
attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
|
163 |
+
|
164 |
+
# hard code memories_spatial to previous selected mask
|
165 |
+
if 'memories_spatial' in self.cross_attn_name:
|
166 |
+
memory_attn_mask = self.spatial_memory['prev_batch_mask']
|
167 |
+
bs,c,_,_ = memory_attn_mask.shape
|
168 |
+
memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
|
169 |
+
memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
|
170 |
+
attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
|
171 |
+
|
172 |
+
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
|
173 |
+
return attn_mask
|
174 |
+
|
175 |
+
def self_attn(self, bs, num_heads):
|
176 |
+
self_attn_name = [key for key, value in self.self_attn_dict.items()
|
177 |
+
if len(value)>0 and key in self.attn_variables
|
178 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
179 |
+
self.self_attn_name = self_attn_name
|
180 |
+
|
181 |
+
output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
|
182 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
|
183 |
+
|
184 |
+
index = 0
|
185 |
+
for name in self_attn_name:
|
186 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
187 |
+
index += self.attn_variables[name].output.shape[0]
|
188 |
+
|
189 |
+
self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
|
190 |
+
self_attn_pair = []
|
191 |
+
# build self_attention mask by query interaction
|
192 |
+
for key1, value in self.self_attn_dict.items():
|
193 |
+
for key2 in value:
|
194 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
195 |
+
# exclude the variables that are not used in the current layer
|
196 |
+
continue
|
197 |
+
if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
|
198 |
+
self_attn_pair += [[key1, key2]]
|
199 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
|
200 |
+
|
201 |
+
# build self_attention mask by masking, for birectional
|
202 |
+
for key in self.masking:
|
203 |
+
if key in self_attn_name:
|
204 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
|
205 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
|
206 |
+
|
207 |
+
# build self_attention mask by masking, for uni-directional
|
208 |
+
for key1, key2 in self_attn_pair:
|
209 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
210 |
+
# exclude the variables that are not used in the current layer
|
211 |
+
continue
|
212 |
+
if key1 in self.masking:
|
213 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
|
214 |
+
if key2 in self.masking:
|
215 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
|
216 |
+
|
217 |
+
self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
|
218 |
+
return output, pos_emb, self_attn_mask
|
219 |
+
|
220 |
+
def update_variables(self, output, mode):
|
221 |
+
name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
|
222 |
+
for key in name_set:
|
223 |
+
self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
|
224 |
+
|
225 |
+
def update_spatial_results(self, results):
|
226 |
+
v_emb = results['pred_smaskembs']
|
227 |
+
pred_smasks = results['pred_smasks']
|
228 |
+
|
229 |
+
s_emb = results['pred_pspatials']
|
230 |
+
pred_logits = v_emb @ s_emb.transpose(1,2)
|
231 |
+
logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
|
232 |
+
logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
|
233 |
+
logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
|
234 |
+
pred_masks_pos = pred_smasks[logits_idx][:,None,]
|
235 |
+
|
236 |
+
# s_emb = results['pred_nspatials']
|
237 |
+
# pred_logits = v_emb @ s_emb.transpose(1,2)
|
238 |
+
# logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
|
239 |
+
# logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
|
240 |
+
# logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
|
241 |
+
# pred_masks_neg = pred_smasks[logits_idx][:,None,]
|
242 |
+
# # clip the negative mask to 0, and then multiply by -1
|
243 |
+
# pred_masks_neg = (pred_masks_neg.clip(0) * -1)
|
244 |
+
# keep_neg = (s_emb.sum(dim=list(range(1, s_emb.dim()))) != 0).float()
|
245 |
+
# pred_masks_neg = pred_masks_neg * keep_neg[:,None,None,None]
|
246 |
+
# extra = {"prev_mask": pred_masks_pos + pred_masks_neg}
|
247 |
+
|
248 |
+
extra = {"prev_mask": pred_masks_pos}
|
249 |
+
return extra
|
250 |
+
|
251 |
+
def organize_output(self, ):
|
252 |
+
outputs = {}
|
253 |
+
outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
|
254 |
+
for key, values in self.output.items():
|
255 |
+
for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
|
256 |
+
if idx_name not in self.query_index:
|
257 |
+
continue
|
258 |
+
outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
259 |
+
for idx, aux_values in enumerate(self.output[key][:-1]):
|
260 |
+
outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
261 |
+
if self.task == 'spatial' or self.task == 'refimg':
|
262 |
+
outputs = self.update_spatial_results(outputs)
|
263 |
+
# outputs = self.update_spatial_results(outputs)
|
264 |
+
return outputs
|
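To make the organize_output convention above concrete, here is a minimal sketch of how predict_name_matcher and predict_index_matcher cooperate: the last entry of each per-layer list becomes the final prediction, earlier entries populate aux_outputs, and every prediction key is narrowed to the query slice it belongs to. Shapes and tensors below are placeholders, not real model outputs:

import torch

predict_name_matcher = {"predictions_mask": ["pred_masks", "pred_gmasks"]}
predict_index_matcher = {"predictions_mask": ["queries_object", "queries_grounding"]}

query_index = {"queries_object": [0, 101], "queries_grounding": [101, 106]}
num_layers = 3

# one [bs, num_queries, H, W] tensor per decoder layer plus the final layer (dummy data)
output = {"predictions_mask": [torch.randn(2, 106, 16, 16) for _ in range(num_layers + 1)]}

outputs = {"aux_outputs": [{} for _ in range(num_layers)]}
for key, values in output.items():
    for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
        if idx_name not in query_index:
            continue
        start, end = query_index[idx_name]
        outputs[_key] = values[-1][:, start:end]
        for idx, aux_values in enumerate(values[:-1]):
            outputs["aux_outputs"][idx][_key] = aux_values[:, start:end]

print(outputs["pred_masks"].shape)   # torch.Size([2, 101, 16, 16])
print(outputs["pred_gmasks"].shape)  # torch.Size([2, 5, 16, 16])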
modeling/interface/prototype/attention_data_struct_seemv1.py
ADDED
@@ -0,0 +1,302 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
predict_name_matcher = {"predictions_class": ["pred_logits"],
|
6 |
+
"predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
|
7 |
+
"predictions_caption":["pred_captions", "pred_gtexts", "pred_stexts"],
|
8 |
+
"predictions_maskemb":["pred_smaskembs"],
|
9 |
+
"predictions_pos_spatial":["pred_pspatials"],
|
10 |
+
"predictions_neg_spatial":["pred_nspatials"],}
|
11 |
+
|
12 |
+
predict_index_matcher = {"predictions_class": ["queries_object"],
|
13 |
+
"predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
|
14 |
+
"predictions_caption": ["queries_object", "queries_grounding", "queries_spatial"],
|
15 |
+
"predictions_maskemb":["queries_spatial"],
|
16 |
+
"predictions_pos_spatial":["all"],
|
17 |
+
"predictions_neg_spatial":["all"],}
|
18 |
+
|
19 |
+
class Variable(object):
|
20 |
+
'''
|
21 |
+
Store dataset variable for attention
|
22 |
+
output: embedding that accumulates during cross/self attention
|
23 |
+
pos: positional embedding that is fixed during cross/self attention
|
24 |
+
name: name of the variable
|
25 |
+
type: type of the variable, e.g. queries, tokens
|
26 |
+
attn_mask: attention mask for cross attention
|
27 |
+
masking: masking for padding
|
28 |
+
'''
|
29 |
+
def __init__(self, output, name, _type, pos=None):
|
30 |
+
self.output = output
|
31 |
+
self.pos = pos
|
32 |
+
self.name = name
|
33 |
+
self.type = _type
|
34 |
+
self.attn_mask = None
|
35 |
+
self.masking = None
|
36 |
+
|
37 |
+
def copy(self,):
|
38 |
+
output = self.output.clone() if self.output is not None else None
|
39 |
+
pos = self.pos.clone() if self.pos is not None else None
|
40 |
+
return Variable(output, self.name, self.type, pos)
|
41 |
+
|
42 |
+
def rand_sample(self, max_len):
|
43 |
+
rand_idx = torch.randint(0, len(self.pos), (max_len,))
|
44 |
+
self.output = self.output[rand_idx]
|
45 |
+
self.pos = self.pos[rand_idx]
|
46 |
+
return self
|
47 |
+
|
48 |
+
class AttentionDataStruct(nn.Module):
|
49 |
+
'''
|
50 |
+
Store dataset structure for cross/self attention
|
51 |
+
task_switch: switch for different tasks
|
52 |
+
|
53 |
+
p_attn_variables: prototype of variables that is used in cross/self attention
|
54 |
+
p_self_attn: prototype of variables that is used in self attention
|
55 |
+
p_cross_attn: prototype of variables that is used in cross attention
|
56 |
+
p_iter: prototype of iteration for different queries
|
57 |
+
p_masking: prototype of masking for different tokens
|
58 |
+
p_duplication: prototype of duplication for different queries
|
59 |
+
'''
|
60 |
+
def __init__(self, attn_arch, task_switch):
|
61 |
+
super(AttentionDataStruct, self).__init__()
|
62 |
+
self.task_switch = task_switch
|
63 |
+
|
64 |
+
# p stands for prototype
|
65 |
+
self.p_attn_variables = attn_arch['VARIABLE']
|
66 |
+
self.p_self_attn = attn_arch['SELF_ATTENTION']
|
67 |
+
self.p_cross_attn = attn_arch['CROSS_ATTENTION']
|
68 |
+
self.p_masking = attn_arch['MASKING']
|
69 |
+
self.p_duplication = attn_arch['DUPLICATION']
|
70 |
+
|
71 |
+
self.num_layers = attn_arch['NUM_LAYERS']
|
72 |
+
|
73 |
+
def reset(self, flags, task, extra):
|
74 |
+
# reset variables
|
75 |
+
self.attn_variables = {}
|
76 |
+
self.cross_attn_dict = {}
|
77 |
+
self.self_attn_dict = {}
|
78 |
+
self.duplication_dict = {}
|
79 |
+
self.query_index = {}
|
80 |
+
self.output = {}
|
81 |
+
self.flags = {}
|
82 |
+
self.spatial_memory = {}
|
83 |
+
self.extra = {}
|
84 |
+
|
85 |
+
# initialize duplication
|
86 |
+
for key, values in self.p_duplication.items():
|
87 |
+
for name in values:
|
88 |
+
self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
|
89 |
+
|
90 |
+
# initialize flag
|
91 |
+
self.flags = {"object": True}
|
92 |
+
self.flags.update(flags)
|
93 |
+
|
94 |
+
# initialize task
|
95 |
+
self.task = task
|
96 |
+
|
97 |
+
# initialize output
|
98 |
+
if self.task_switch['mask']:
|
99 |
+
self.output['predictions_class'] = []
|
100 |
+
self.output['predictions_mask'] = []
|
101 |
+
|
102 |
+
if self.task_switch['bbox']:
|
103 |
+
self.output['predictions_bbox'] = []
|
104 |
+
|
105 |
+
if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
|
106 |
+
self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
|
107 |
+
|
108 |
+
if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
|
109 |
+
self.output['predictions_caption'] = []
|
110 |
+
|
111 |
+
if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
|
112 |
+
self.output['predictions_maskemb'] = []
|
113 |
+
self.output['predictions_pos_spatial'] = []
|
114 |
+
self.output['predictions_neg_spatial'] = []
|
115 |
+
self.output['predictions_mask'] = [] if 'predictions_mask' not in self.output else self.output['predictions_mask']
|
116 |
+
self.output['predictions_class'] = [] if 'predictions_class' not in self.output else self.output['predictions_class']
|
117 |
+
self.output['predictions_caption'] = [] if 'predictions_caption' not in self.output else self.output['predictions_caption']
|
118 |
+
|
119 |
+
# initialize cross_attn, whether the variable is used in cross attention
|
120 |
+
for key, values in self.p_cross_attn.items():
|
121 |
+
for name in values:
|
122 |
+
self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
|
123 |
+
|
124 |
+
# initialize self_attn, whether the variable is used in self attention, and the interactions between queries
|
125 |
+
for key, values in self.p_self_attn.items():
|
126 |
+
for name in values:
|
127 |
+
self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
|
128 |
+
|
129 |
+
# initialize masking
|
130 |
+
self.masking = self.p_masking
|
131 |
+
|
132 |
+
# initialize query_index
|
133 |
+
self.query_index = {"all":[0, None]}
|
134 |
+
|
135 |
+
|
136 |
+
def set(self, name, _type, output=None, pos=None, var=None, sample_size=None):
|
137 |
+
if var is not None:
|
138 |
+
self.attn_variables[name] = var
|
139 |
+
elif name in self.duplication_dict:
|
140 |
+
assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
|
141 |
+
var = self.attn_variables[self.duplication_dict[name]].copy()
|
142 |
+
if sample_size is not None:
|
143 |
+
var = var.rand_sample(sample_size)
|
144 |
+
self.attn_variables[name] = var
|
145 |
+
else:
|
146 |
+
var = Variable(output, name, _type, pos)
|
147 |
+
self.attn_variables[name] = var
|
148 |
+
|
149 |
+
def set_results(self, results):
|
150 |
+
for name in self.cross_attn_name:
|
151 |
+
self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
|
152 |
+
for key in self.output:
|
153 |
+
self.output[key].append(results[key])
|
154 |
+
|
155 |
+
def set_maskings(self, name, masking):
|
156 |
+
self.attn_variables[name].masking = masking
|
157 |
+
|
158 |
+
def set_extra(self, extra):
|
159 |
+
self.extra.update(extra)
|
160 |
+
|
161 |
+
def cross_attn_variables(self, ):
|
162 |
+
cross_attn_name = [key for key, value in self.cross_attn_dict.items()
|
163 |
+
if (value==True) and (key in self.attn_variables)
|
164 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
165 |
+
self.cross_attn_name = cross_attn_name
|
166 |
+
|
167 |
+
output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
|
168 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
|
169 |
+
|
170 |
+
index = 0
|
171 |
+
for name in cross_attn_name:
|
172 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
173 |
+
index += self.attn_variables[name].output.shape[0]
|
174 |
+
return output, pos_emb
|
175 |
+
|
176 |
+
def cross_attn_mask(self, size, num_heads):
|
177 |
+
attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
|
178 |
+
|
179 |
+
# hard code memories_spatial to previous selected mask
|
180 |
+
if 'memories_spatial' in self.cross_attn_name:
|
181 |
+
memory_attn_mask = self.spatial_memory['prev_batch_mask']
|
182 |
+
bs,c,_,_ = memory_attn_mask.shape
|
183 |
+
memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
|
184 |
+
memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
|
185 |
+
repeat = (self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]) // c
|
186 |
+
mem_len = self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]
|
187 |
+
probs = torch.tensor([1./repeat for i in range(c)])
|
188 |
+
indices = torch.multinomial(probs, num_samples=mem_len, replacement=True).sort()[0]
|
189 |
+
attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask[:,indices]
|
190 |
+
self.extra['memory_indices'] = indices
|
191 |
+
|
192 |
+
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
|
193 |
+
return attn_mask
|
194 |
+
|
195 |
+
def self_attn(self, bs, num_heads):
|
196 |
+
self_attn_name = [key for key, value in self.self_attn_dict.items()
|
197 |
+
if len(value)>0 and key in self.attn_variables
|
198 |
+
and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
|
199 |
+
self.self_attn_name = self_attn_name
|
200 |
+
|
201 |
+
output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
|
202 |
+
pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
|
203 |
+
|
204 |
+
index = 0
|
205 |
+
for name in self_attn_name:
|
206 |
+
self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
|
207 |
+
index += self.attn_variables[name].output.shape[0]
|
208 |
+
|
209 |
+
self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
|
210 |
+
self_attn_pair = []
|
211 |
+
# build self_attention mask by query interaction
|
212 |
+
for key1, value in self.self_attn_dict.items():
|
213 |
+
for key2 in value:
|
214 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
215 |
+
# exclude the variables that are not used in the current layer
|
216 |
+
continue
|
217 |
+
if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
|
218 |
+
self_attn_pair += [[key1, key2]]
|
219 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
|
220 |
+
|
221 |
+
# build self_attention mask by masking, for birectional
|
222 |
+
for key in self.masking:
|
223 |
+
if key in self_attn_name:
|
224 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
|
225 |
+
self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
|
226 |
+
|
227 |
+
# build self_attention mask by masking, for uni-directional
|
228 |
+
for key1, key2 in self_attn_pair:
|
229 |
+
if key1 not in self_attn_name or key2 not in self_attn_name:
|
230 |
+
# exclude the variables that are not used in the current layer
|
231 |
+
continue
|
232 |
+
if key1 in self.masking:
|
233 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
|
234 |
+
if key2 in self.masking:
|
235 |
+
self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
|
236 |
+
|
237 |
+
# build self_attention mask masking for spatial query
|
238 |
+
# spatial query attend with itself
|
239 |
+
if 'queries_spatial' in self_attn_name and 'tokens_spatial' in self_attn_name:
|
240 |
+
diag_mask = ~(torch.eye(self.extra['spatial_query_number']).repeat_interleave(self.extra['sample_size'],dim=0).repeat_interleave(self.extra['sample_size'],dim=1)).bool()
|
241 |
+
self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1]] = diag_mask[None,]
|
242 |
+
# spatial query attend with spatial token
|
243 |
+
indices = self.extra['spatial_indices'].permute(0,2,1)
|
244 |
+
diag_index = torch.arange(self.extra['spatial_query_number'], device=indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
|
245 |
+
diag_mask = ~(indices == diag_index)
|
246 |
+
self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
|
247 |
+
# spatial token attend with itself
|
248 |
+
diag_mask = ~(indices == indices.transpose(1,2))
|
249 |
+
self_attn_mask[:,self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
|
250 |
+
|
251 |
+
if 'memory_indices' in self.extra:
|
252 |
+
# spatial query attend with memory
|
253 |
+
memory_indices = self.extra['memory_indices'][None,None,:]
|
254 |
+
diag_index = torch.arange(self.extra['spatial_query_number'], device=memory_indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
|
255 |
+
diag_mask = ~(diag_index == memory_indices)
|
256 |
+
self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
|
257 |
+
# memory attend with itself
|
258 |
+
diag_mask = ~(memory_indices == memory_indices.transpose(1,2))
|
259 |
+
self_attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
|
260 |
+
|
261 |
+
self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
|
262 |
+
return output, pos_emb, self_attn_mask
|
263 |
+
|
264 |
+
def update_variables(self, output, mode):
|
265 |
+
name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
|
266 |
+
for key in name_set:
|
267 |
+
self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
|
268 |
+
|
269 |
+
def update_spatial_results(self, results):
|
270 |
+
v_emb = results['pred_smaskembs']
|
271 |
+
pred_smasks = results['pred_smasks']
|
272 |
+
|
273 |
+
s_emb = results['pred_pspatials']
|
274 |
+
diag_mask = ~(torch.eye(self.extra['spatial_query_number'], device=s_emb.device).repeat_interleave(self.extra['sample_size'],dim=0)).bool()
|
275 |
+
offset = torch.zeros_like(diag_mask, device=s_emb.device).float()
|
276 |
+
offset.masked_fill_(diag_mask, float("-inf"))
|
277 |
+
|
278 |
+
pred_logits = v_emb @ s_emb.transpose(1,2) + offset[None,]
|
279 |
+
bs,_,ns=pred_logits.shape
|
280 |
+
_,_,h,w=pred_smasks.shape
|
281 |
+
|
282 |
+
logits_idx_y = pred_logits.max(dim=1)[1]
|
283 |
+
logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)[:,None].repeat(1, logits_idx_y.shape[1])
|
284 |
+
logits_idx = torch.stack([logits_idx_x, logits_idx_y]).view(2,-1).tolist()
|
285 |
+
pred_masks_pos = pred_smasks[logits_idx].reshape(bs,ns,h,w)
|
286 |
+
extra = {"prev_mask": pred_masks_pos}
|
287 |
+
return extra
|
288 |
+
|
289 |
+
def organize_output(self, ):
|
290 |
+
outputs = {}
|
291 |
+
outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
|
292 |
+
for key, values in self.output.items():
|
293 |
+
for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
|
294 |
+
if idx_name not in self.query_index:
|
295 |
+
continue
|
296 |
+
outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
297 |
+
for idx, aux_values in enumerate(self.output[key][:-1]):
|
298 |
+
outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
|
299 |
+
if self.task == 'spatial' or self.task == 'refimg':
|
300 |
+
outputs = self.update_spatial_results(outputs)
|
301 |
+
# outputs = self.update_spatial_results(outputs)
|
302 |
+
return outputs
|
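The spatial-result update in this file is essentially a dot product between mask embeddings and the positive spatial query embedding followed by an argmax that picks one mask per spatial query; the selected mask is fed back as prev_mask. A toy single-query version with made-up shapes (the real method also adds the diagonal offset so each of several spatial queries can only select its own candidates):

import torch

bs, nq, ns, h, w, d = 2, 101, 1, 16, 16, 32   # made-up sizes
v_emb = torch.randn(bs, nq, d)                # stand-in for pred_smaskembs
s_emb = torch.randn(bs, ns, d)                # stand-in for pred_pspatials
pred_smasks = torch.randn(bs, nq, h, w)       # stand-in for pred_smasks

# similarity of every candidate mask embedding with the spatial query embedding
pred_logits = v_emb @ s_emb.transpose(1, 2)   # [bs, nq, ns]

# argmax over candidates, one selected mask per batch element
logits_idx_y = pred_logits[:, :, 0].max(dim=1)[1]
logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
prev_mask = pred_smasks[logits_idx][:, None]  # [bs, 1, h, w], reused as "prev_mask"
print(prev_mask.shape)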
modeling/interface/seem_demo.py
ADDED
@@ -0,0 +1,397 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# SEEM -- Segment Everything Everywhere All At Once
|
3 |
+
# Licensed under The Apache License 2.0 [see LICENSE for details]
|
4 |
+
# Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected])
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn, Tensor
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
from timm.models.layers import trunc_normal_
|
15 |
+
from detectron2.layers import Conv2d
|
16 |
+
import fvcore.nn.weight_init as weight_init
|
17 |
+
|
18 |
+
from .build import register_decoder
|
19 |
+
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
|
20 |
+
from .prototype.attention_data_struct_seemdemo import AttentionDataStruct
|
21 |
+
from ..utils import rand_sample_plain as rand_sample
|
22 |
+
from ..utils import prepare_features, configurable
|
23 |
+
from ..modules import PositionEmbeddingSine
|
24 |
+
from ..modules.point_features import point_sample
|
25 |
+
|
26 |
+
|
27 |
+
class SEEMDecoder(nn.Module):
|
28 |
+
|
29 |
+
@configurable
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
lang_encoder: nn.Module,
|
33 |
+
in_channels,
|
34 |
+
mask_classification=True,
|
35 |
+
*,
|
36 |
+
hidden_dim: int,
|
37 |
+
dim_proj: int,
|
38 |
+
num_queries: int,
|
39 |
+
contxt_len: int,
|
40 |
+
nheads: int,
|
41 |
+
dim_feedforward: int,
|
42 |
+
dec_layers: int,
|
43 |
+
pre_norm: bool,
|
44 |
+
mask_dim: int,
|
45 |
+
task_switch: dict,
|
46 |
+
enforce_input_project: bool,
|
47 |
+
max_spatial_len: int,
|
48 |
+
attn_arch: dict,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
NOTE: this interface is experimental.
|
52 |
+
Args:
|
53 |
+
in_channels: channels of the input features
|
54 |
+
mask_classification: whether to add mask classifier or not
|
55 |
+
num_classes: number of classes
|
56 |
+
hidden_dim: Transformer feature dimension
|
57 |
+
num_queries: number of queries
|
58 |
+
nheads: number of heads
|
59 |
+
dim_feedforward: feature dimension in feedforward network
|
60 |
+
enc_layers: number of Transformer encoder layers
|
61 |
+
dec_layers: number of Transformer decoder layers
|
62 |
+
pre_norm: whether to use pre-LayerNorm or not
|
63 |
+
mask_dim: mask feature dimension
|
64 |
+
enforce_input_project: add input project 1x1 conv even if input
|
65 |
+
channels and hidden dim is identical
|
66 |
+
"""
|
67 |
+
super().__init__()
|
68 |
+
assert mask_classification, "Only support mask classification model"
|
69 |
+
self.mask_classification = mask_classification
|
70 |
+
|
71 |
+
# positional encoding
|
72 |
+
N_steps = hidden_dim // 2
|
73 |
+
self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
|
74 |
+
|
75 |
+
# define Transformer decoder here
|
76 |
+
self.num_heads = nheads
|
77 |
+
self.num_layers = dec_layers
|
78 |
+
self.contxt_len = contxt_len
|
79 |
+
self.transformer_self_attention_layers = nn.ModuleList()
|
80 |
+
self.transformer_cross_attention_layers = nn.ModuleList()
|
81 |
+
self.transformer_ffn_layers = nn.ModuleList()
|
82 |
+
|
83 |
+
for _ in range(self.num_layers):
|
84 |
+
self.transformer_self_attention_layers.append(
|
85 |
+
SelfAttentionLayer(
|
86 |
+
d_model=hidden_dim,
|
87 |
+
nhead=nheads,
|
88 |
+
dropout=0.0,
|
89 |
+
normalize_before=pre_norm,
|
90 |
+
)
|
91 |
+
)
|
92 |
+
|
93 |
+
self.transformer_cross_attention_layers.append(
|
94 |
+
CrossAttentionLayer(
|
95 |
+
d_model=hidden_dim,
|
96 |
+
nhead=nheads,
|
97 |
+
dropout=0.0,
|
98 |
+
normalize_before=pre_norm,
|
99 |
+
)
|
100 |
+
)
|
101 |
+
|
102 |
+
self.transformer_ffn_layers.append(
|
103 |
+
FFNLayer(
|
104 |
+
d_model=hidden_dim,
|
105 |
+
dim_feedforward=dim_feedforward,
|
106 |
+
dropout=0.0,
|
107 |
+
normalize_before=pre_norm,
|
108 |
+
)
|
109 |
+
)
|
110 |
+
|
111 |
+
self.decoder_norm = nn.LayerNorm(hidden_dim)
|
112 |
+
|
113 |
+
self.num_queries = num_queries
|
114 |
+
# learnable query features
|
115 |
+
self.query_feat = nn.Embedding(num_queries, hidden_dim)
|
116 |
+
# learnable query p.e.
|
117 |
+
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
118 |
+
# learnable positive negative indicator
|
119 |
+
self.pn_indicator = nn.Embedding(2, hidden_dim)
|
120 |
+
|
121 |
+
# level embedding (we always use 3 scales)
|
122 |
+
self.num_feature_levels = 3
|
123 |
+
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
|
124 |
+
self.input_proj = nn.ModuleList()
|
125 |
+
|
126 |
+
for _ in range(self.num_feature_levels):
|
127 |
+
if in_channels != hidden_dim or enforce_input_project:
|
128 |
+
self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
|
129 |
+
weight_init.c2_xavier_fill(self.input_proj[-1])
|
130 |
+
else:
|
131 |
+
self.input_proj.append(nn.Sequential())
|
132 |
+
|
133 |
+
self.task_switch = task_switch
|
134 |
+
self.query_index = {}
|
135 |
+
|
136 |
+
# output FFNs
|
137 |
+
self.lang_encoder = lang_encoder
|
138 |
+
if self.task_switch['mask']:
|
139 |
+
self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
|
140 |
+
|
141 |
+
self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
|
142 |
+
trunc_normal_(self.class_embed, std=.02)
|
143 |
+
|
144 |
+
if task_switch['bbox']:
|
145 |
+
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
146 |
+
|
147 |
+
if task_switch['spatial']:
|
148 |
+
# spatial query
|
149 |
+
self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
|
150 |
+
trunc_normal_(self.mask_sptial_embed[0], std=.02)
|
151 |
+
trunc_normal_(self.mask_sptial_embed[1], std=.02)
|
152 |
+
trunc_normal_(self.mask_sptial_embed[2], std=.02)
|
153 |
+
|
154 |
+
self.max_spatial_len = max_spatial_len
|
155 |
+
# spatial memory
|
156 |
+
num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
|
157 |
+
self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
|
158 |
+
self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
|
159 |
+
|
160 |
+
# build AttentionDataStruct
|
161 |
+
attn_arch['NUM_LAYERS'] = self.num_layers
|
162 |
+
self.attention_data = AttentionDataStruct(attn_arch, task_switch)
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
|
166 |
+
ret = {}
|
167 |
+
|
168 |
+
ret["lang_encoder"] = lang_encoder
|
169 |
+
ret["in_channels"] = in_channels
|
170 |
+
ret["mask_classification"] = mask_classification
|
171 |
+
|
172 |
+
enc_cfg = cfg['MODEL']['ENCODER']
|
173 |
+
dec_cfg = cfg['MODEL']['DECODER']
|
174 |
+
|
175 |
+
ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
|
176 |
+
ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
|
177 |
+
ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
|
178 |
+
ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
|
179 |
+
|
180 |
+
# Transformer parameters:
|
181 |
+
ret["nheads"] = dec_cfg['NHEADS']
|
182 |
+
ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
|
183 |
+
|
184 |
+
# NOTE: because we add learnable query features which requires supervision,
|
185 |
+
# we add minus 1 to decoder layers to be consistent with our loss
|
186 |
+
# implementation: that is, number of auxiliary losses is always
|
187 |
+
# equal to number of decoder layers. With learnable query features, the number of
|
188 |
+
# auxiliary losses equals number of decoders plus 1.
|
189 |
+
assert dec_cfg['DEC_LAYERS'] >= 1
|
190 |
+
ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
|
191 |
+
ret["pre_norm"] = dec_cfg['PRE_NORM']
|
192 |
+
ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
|
193 |
+
ret["mask_dim"] = enc_cfg['MASK_DIM']
|
194 |
+
ret["task_switch"] = extra['task_switch']
|
195 |
+
ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
|
196 |
+
|
197 |
+
# attn data struct
|
198 |
+
ret["attn_arch"] = cfg['ATTENTION_ARCH']
|
199 |
+
|
200 |
+
return ret
|
201 |
+
|
202 |
+
def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
|
203 |
+
# x is a list of multi-scale feature
|
204 |
+
assert len(x) == self.num_feature_levels; del mask
|
205 |
+
spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg'
|
206 |
+
grounding_extra_flag = 'grounding_tokens' in extra.keys()
|
207 |
+
visual_extra_flag = 'visual_query_pos' in extra.keys()
|
208 |
+
audio_extra_flag = 'audio_tokens' in extra.keys()
|
209 |
+
spatial_memory_flag = 'prev_mask' in extra.keys()
|
210 |
+
flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag, "visual": visual_extra_flag, "audio": audio_extra_flag}
|
211 |
+
self.attention_data.reset(flags, task, extra)
|
212 |
+
|
213 |
+
src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
|
214 |
+
_, bs, _ = src[0].shape
|
215 |
+
|
216 |
+
# QxNxC
|
217 |
+
query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
|
218 |
+
output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
|
219 |
+
self.attention_data.set('queries_object', 'queries', output, query_embed)
|
220 |
+
|
221 |
+
if self.task_switch['spatial'] and spatial_extra_flag:
|
222 |
+
# get divisor
|
223 |
+
_,h,w = extra['spatial_query_pos_mask'][0].shape
|
224 |
+
divisor = torch.tensor([h,w], device=output.device)[None,]
|
225 |
+
|
226 |
+
# Get mean pos spatial query
|
227 |
+
non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
|
228 |
+
non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
|
229 |
+
non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
|
230 |
+
spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
|
231 |
+
spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
|
232 |
+
|
233 |
+
# Get mean neg spatial query
|
234 |
+
non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
|
235 |
+
non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
|
236 |
+
non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
|
237 |
+
spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
|
238 |
+
spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
|
239 |
+
|
240 |
+
# merge positive and negative sample points for self attention
|
241 |
+
|
242 |
+
# Get layerwise spatial query
|
243 |
+
src_spatial_queries = []
|
244 |
+
src_spatial_maskings = []
|
245 |
+
for i in range(len(src)):
|
246 |
+
hw,_,dc = src[i].shape
|
247 |
+
src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
|
248 |
+
src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
|
249 |
+
|
250 |
+
non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
|
251 |
+
non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
|
252 |
+
non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
|
253 |
+
|
254 |
+
pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
|
255 |
+
pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
|
256 |
+
|
257 |
+
non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
|
258 |
+
non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
|
259 |
+
non_zero_query_point[non_zero_query_mask] = 0
|
260 |
+
|
261 |
+
spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
|
262 |
+
spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
|
263 |
+
spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
|
264 |
+
|
265 |
+
src_spatial_queries += [spatial_tokens]
|
266 |
+
src_spatial_maskings += [non_zero_query_mask]
|
267 |
+
|
268 |
+
if 'refimg' in task:
|
269 |
+
output_refimg = {}
|
270 |
+
output_refimg['visual_query_pos'] = spatial_query_pos
|
271 |
+
output_refimg['visual_query_neg'] = spatial_query_neg
|
272 |
+
output_refimg['src_visual_queries'] = src_spatial_queries
|
273 |
+
output_refimg['src_visual_maskings'] = src_spatial_maskings
|
274 |
+
return output_refimg
|
275 |
+
|
276 |
+
if task != 'demo':
|
277 |
+
# Get object query for spatial index
|
278 |
+
self.attention_data.set('queries_spatial', 'queries')
|
279 |
+
|
280 |
+
if self.task_switch['visual'] and visual_extra_flag:
|
281 |
+
visual_query_pos = extra['visual_query_pos']
|
282 |
+
visual_query_neg = extra['visual_query_neg']
|
283 |
+
src_visual_queries = extra['src_visual_queries']
|
284 |
+
src_visual_maskings = extra['src_visual_maskings']
|
285 |
+
|
286 |
+
if self.task_switch['grounding'] and grounding_extra_flag:
|
287 |
+
# Get grounding tokens
|
288 |
+
grounding_tokens = extra['grounding_tokens']
|
289 |
+
_grounding_tokens = grounding_tokens.detach().clone()
|
290 |
+
|
291 |
+
self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
|
292 |
+
self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
|
293 |
+
|
294 |
+
if self.task_switch['audio'] and audio_extra_flag:
|
295 |
+
# Get grounding tokens
|
296 |
+
grounding_tokens = extra['audio_tokens']
|
297 |
+
_grounding_tokens = grounding_tokens.detach().clone()
|
298 |
+
|
299 |
+
self.attention_data.set('tokens_audio', 'tokens', grounding_tokens, _grounding_tokens)
|
300 |
+
self.attention_data.set_maskings('tokens_audio', extra['audio_nonzero_mask'])
|
301 |
+
|
302 |
+
output, query_embed = self.attention_data.cross_attn_variables()
|
303 |
+
# prediction heads on learnable query features
|
304 |
+
results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
|
305 |
+
results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
|
306 |
+
results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
|
307 |
+
results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
|
308 |
+
results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
|
309 |
+
self.attention_data.set_results(results)
|
310 |
+
|
311 |
+
for i in range(self.num_layers):
|
312 |
+
level_index = i % self.num_feature_levels
|
313 |
+
# CROSS ATTENTION
|
314 |
+
output, avg_attn = self.transformer_cross_attention_layers[i](
|
315 |
+
output, src[level_index],
|
316 |
+
memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
|
317 |
+
memory_key_padding_mask=None, # here we do not apply masking on padded region
|
318 |
+
pos=pos[level_index], query_pos=query_embed
|
319 |
+
)
|
320 |
+
self.attention_data.update_variables(output, 'cross_attn')
|
321 |
+
|
322 |
+
# SELF ATTENTION
|
323 |
+
self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
|
324 |
+
if self.task_switch['spatial'] and spatial_extra_flag:
|
325 |
+
# get spatial tokens
|
326 |
+
spatial_tokens = src_spatial_queries[level_index]
|
327 |
+
_spatial_tokens = spatial_tokens.detach().clone()
|
328 |
+
|
329 |
+
self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
|
330 |
+
self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
|
331 |
+
|
332 |
+
if self.task_switch['visual'] and visual_extra_flag:
|
333 |
+
# get spatial tokens
|
334 |
+
visual_tokens = src_visual_queries[level_index]
|
335 |
+
_visual_tokens = visual_tokens.detach().clone()
|
336 |
+
|
337 |
+
self.attention_data.set('tokens_visual', 'tokens', visual_tokens, _visual_tokens)
|
338 |
+
self.attention_data.set_maskings('tokens_visual', src_visual_maskings[level_index])
|
339 |
+
|
340 |
+
output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
|
341 |
+
output = self.transformer_self_attention_layers[i](
|
342 |
+
output, tgt_mask=self_attn_mask,
|
343 |
+
tgt_key_padding_mask=None,
|
344 |
+
query_pos=query_embed)
|
345 |
+
|
346 |
+
# FFN
|
347 |
+
output = self.transformer_ffn_layers[i](
|
348 |
+
output
|
349 |
+
)
|
350 |
+
|
351 |
+
self.attention_data.update_variables(output, 'self_attn')
|
352 |
+
output, query_embed = self.attention_data.cross_attn_variables()
|
353 |
+
results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
|
354 |
+
results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
|
355 |
+
results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
|
356 |
+
results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
|
357 |
+
results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
|
358 |
+
self.attention_data.set_results(results)
|
359 |
+
|
360 |
+
return self.attention_data.organize_output()
|
361 |
+
|
362 |
+
def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
|
363 |
+
decoder_output = self.decoder_norm(output)
|
364 |
+
decoder_output = decoder_output.transpose(0, 1)
|
365 |
+
class_embed = decoder_output @ self.class_embed
|
366 |
+
outputs_class = self.lang_encoder.compute_similarity(class_embed)
|
367 |
+
mask_embed = self.mask_embed(decoder_output)
|
368 |
+
outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
|
369 |
+
|
370 |
+
outputs_bbox = [None for i in range(len(outputs_mask))]
|
371 |
+
if self.task_switch['bbox']:
|
372 |
+
outputs_bbox = self.bbox_embed(decoder_output)
|
373 |
+
|
374 |
+
# NOTE: prediction is of higher-resolution
|
375 |
+
# [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
|
376 |
+
attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
|
377 |
+
|
378 |
+
# must use bool type
|
379 |
+
# If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
|
380 |
+
attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
|
381 |
+
attn_mask = attn_mask.detach()
|
382 |
+
|
383 |
+
outputs_caption = class_embed
|
384 |
+
|
385 |
+
results = {
|
386 |
+
"attn_mask": attn_mask,
|
387 |
+
"predictions_class": outputs_class,
|
388 |
+
"predictions_mask": outputs_mask,
|
389 |
+
"predictions_bbox": outputs_bbox,
|
390 |
+
"predictions_caption": outputs_caption,
|
391 |
+
"predictions_maskemb": mask_embed,
|
392 |
+
}
|
393 |
+
return results
|
394 |
+
|
395 |
+
@register_decoder
|
396 |
+
def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
|
397 |
+
return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
|
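One step of seem_demo.py worth unpacking is how user clicks or scribbles become spatial queries: nonzero mask coordinates are normalized by the image size, a random subset is kept, and features are read out at those points with bilinear sampling. The sketch below is only an approximation of that idea; it uses torch.nn.functional.grid_sample directly instead of detectron2's point_sample, and every name and size in it is a placeholder rather than a value from this repository:

import torch
import torch.nn.functional as F

c, h, w = 512, 32, 32                       # placeholder feature size
mask_features = torch.randn(1, c, h, w)     # per-image mask features
click_mask = torch.zeros(1, h, w)           # a fake positive scribble
click_mask[0, 10:14, 20:25] = 1

# normalized (y, x) coordinates of the scribble, as in the decoder
divisor = torch.tensor([h, w])[None, :]
points = click_mask[0].nonzero() / divisor           # [N, 2] in [0, 1]

# random subsample to a fixed budget (stand-in for rand_sample)
max_len = 8
points = points[torch.randint(0, len(points), (max_len,))]

# grid_sample wants (x, y) in [-1, 1], so flip the coordinate order and rescale
grid = points.flip(dims=(1,)) * 2 - 1                # [max_len, 2]
sampled = F.grid_sample(mask_features, grid[None, :, None, :],
                        mode='bilinear', align_corners=True)
spatial_query = sampled.squeeze(-1).mean(dim=-1)     # [1, c] mean positive query
print(spatial_query.shape)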
modeling/interface/seem_v0.py
ADDED
@@ -0,0 +1,392 @@
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import logging
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from .build import register_decoder
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from .prototype.attention_data_struct_seemv0 import AttentionDataStruct
from ..utils import rand_sample_plain as rand_sample
from ..utils import prepare_features, configurable
from ..modules import PositionEmbeddingSine
from ..modules.point_features import point_sample


class SEEMDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        contxt_len: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        task_switch: dict,
        enforce_input_project: bool,
        max_spatial_len: int,
        attn_arch: dict,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()
        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.contxt_len = contxt_len
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        self.task_switch = task_switch
        self.query_index = {}

        # output FFNs
        self.lang_encoder = lang_encoder
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        trunc_normal_(self.class_embed, std=.02)

        if task_switch['bbox']:
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        if task_switch['spatial']:
            # spatial query
            self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
            trunc_normal_(self.mask_sptial_embed[0], std=.02)
            trunc_normal_(self.mask_sptial_embed[1], std=.02)
            trunc_normal_(self.mask_sptial_embed[2], std=.02)

            self.max_spatial_len = max_spatial_len
            # spatial memory
            num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
            self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
            self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)

        # learnable positive negative indicator
        self.pn_indicator = nn.Embedding(2, hidden_dim)

        # build AttentionDataStruct
        attn_arch['NUM_LAYERS'] = self.num_layers
        self.attention_data = AttentionDataStruct(attn_arch, task_switch)

    @classmethod
    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
        ret = {}

        ret["lang_encoder"] = lang_encoder
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
        ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
        ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
        ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        # Transformer parameters:
        ret["nheads"] = dec_cfg['NHEADS']
        ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert dec_cfg['DEC_LAYERS'] >= 1
        ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
        ret["pre_norm"] = dec_cfg['PRE_NORM']
        ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["task_switch"] = extra['task_switch']
        ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']

        # attn data struct
        ret["attn_arch"] = cfg['ATTENTION_ARCH']

        return ret

    def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels; del mask
        spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
        grounding_extra_flag = 'grounding_tokens' in extra.keys()
        spatial_memory_flag = 'prev_mask' in extra.keys()
        flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
        self.attention_data.reset(flags, task, extra)

        src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        self.attention_data.set('queries_object', 'queries', output, query_embed)

        if self.task_switch['spatial'] and spatial_extra_flag:
            if 'refimg_tokens' not in extra:
                # get divisor
                _,h,w = extra['spatial_query_pos_mask'][0].shape
                divisor = torch.tensor([h,w], device=output.device)[None,]

                # Get mean pos spatial query
                non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
                non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
                non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
                spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()

                # Get mean neg spatial query
                non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
                non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
                non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
                spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()

                # merge positive and negative sample points for self attention
                # pos_neg_points = [x|y for x,y in zip(extra['spatial_query_pos_mask'], extra['spatial_query_neg_mask'])]

                # Get layerwise spatial query
                src_spatial_queries = []
                src_spatial_maskings = []
                for i in range(len(src)):
                    hw,_,dc = src[i].shape
                    src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
                    src_mask_features = src_mask_features @ self.mask_sptial_embed[i]

                    non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
                    non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
                    non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]

                    pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                    pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)

                    non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
                    non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
                    non_zero_query_point[non_zero_query_mask] = 0

                    spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
                    spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
                    spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]

                    src_spatial_queries += [spatial_tokens]
                    src_spatial_maskings += [non_zero_query_mask]

                if 'refimg' in task:
                    output_refimg = {}
                    output_refimg['spatial_query_pos'] = spatial_query_pos
                    output_refimg['spatial_query_neg'] = spatial_query_neg
                    output_refimg['src_spatial_queries'] = src_spatial_queries
                    output_refimg['src_spatial_maskings'] = src_spatial_maskings
                    return output_refimg
            else:
                spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
                spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
                src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
                src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']

            # Get object query for spatial index
            self.attention_data.set('queries_spatial', 'queries')

            # set spatial memory
            spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
            spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
            self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)

            # if 'queries_spatial' in extra:
            #     self.attention_data.set('queries_spatial', 'queries', var=extra['queries_spatial'])

        # if spatial_memory_flag:
        #     prev_mask = (extra['prev_mask'].sigmoid() > 0.5).detach()
        #     non_zero_query_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in prev_mask]
        #     non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
        #     non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
        #     spatial_memory = point_sample(mask_features, non_zero_query_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
        #     spatial_memory = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_memory.transpose(1,2), ~non_zero_query_mask)]).transpose(0,1).nan_to_num()

        if self.task_switch['grounding'] and grounding_extra_flag:
            # Get grounding tokens
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()

            self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
            self.attention_data.set('queries_grounding', 'queries')
            self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])

        output, query_embed = self.attention_data.cross_attn_variables()
        # prediction heads on learnable query features
        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
        results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
        results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
        self.attention_data.set_results(results)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # CROSS ATTENTION
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )
            self.attention_data.update_variables(output, 'cross_attn')

            # SELF ATTENTION
            self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
            if self.task_switch['spatial'] and spatial_extra_flag:
                # get spatial tokens
                spatial_tokens = src_spatial_queries[level_index]
                _spatial_tokens = spatial_tokens.detach().clone()

                self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
                self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])

            output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=self_attn_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            self.attention_data.update_variables(output, 'self_attn')
            output, query_embed = self.attention_data.cross_attn_variables()
            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
            results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
            results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
            self.attention_data.set_results(results)

        return self.attention_data.organize_output()

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        class_embed = decoder_output @ self.class_embed
        outputs_class = self.lang_encoder.compute_similarity(class_embed)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        outputs_bbox = [None for i in range(len(outputs_mask))]
        if self.task_switch['bbox']:
            outputs_bbox = self.bbox_embed(decoder_output)

        # NOTE: prediction is of higher-resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)

        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        outputs_caption = class_embed

        results = {
            "attn_mask": attn_mask,
            "predictions_class": outputs_class,
            "predictions_mask": outputs_mask,
            "predictions_bbox": outputs_bbox,
            "predictions_caption": outputs_caption,
            "predictions_maskemb": mask_embed,
        }
        return results


@register_decoder
def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
    return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
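Illustrative sketch (not part of the commit): SEEMDecoder.from_config above pulls its hyper-parameters out of a nested config dict. The layout below is inferred from the keys that from_config reads; the concrete values are placeholders, not the project's defaults.

```python
cfg = {
    'MODEL': {
        'DIM_PROJ': 512,
        'TEXT': {'CONTEXT_LENGTH': 77},
        'ENCODER': {'MASK_DIM': 512},
        'DECODER': {
            'HIDDEN_DIM': 512, 'NUM_OBJECT_QUERIES': 101, 'NHEADS': 8,
            'DIM_FEEDFORWARD': 2048, 'DEC_LAYERS': 10, 'PRE_NORM': False,
            'ENFORCE_INPUT_PROJ': False, 'MAX_SPATIAL_LEN': [512, 512, 512],
        },
    },
    'ATTENTION_ARCH': {'SPATIAL_MEMORIES': 32},
}
# get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra) would
# then build the decoder with dec_layers = DEC_LAYERS - 1 internal layers, per the NOTE
# about auxiliary losses in from_config.
```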
modeling/interface/seem_v1.py
ADDED
@@ -0,0 +1,389 @@
# --------------------------------------------------------
# SEEM -- Segment Everything Everywhere All at Once
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import logging
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from .build import register_decoder
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from .prototype.attention_data_struct_seemv1 import AttentionDataStruct
from ..utils import rand_sample, prepare_features, configurable
from ..modules import PositionEmbeddingSine
from ..modules.point_features import point_sample


class SEEMDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        contxt_len: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        task_switch: dict,
        enforce_input_project: bool,
        max_spatial_len: int,
        attn_arch: dict,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()
        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.contxt_len = contxt_len
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        self.task_switch = task_switch
        self.query_index = {}

        # output FFNs
        self.lang_encoder = lang_encoder
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        trunc_normal_(self.class_embed, std=.02)

        if task_switch['bbox']:
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        if task_switch['spatial']:
            # spatial query
            self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
            trunc_normal_(self.mask_sptial_embed[0], std=.02)
            trunc_normal_(self.mask_sptial_embed[1], std=.02)
            trunc_normal_(self.mask_sptial_embed[2], std=.02)

            self.max_spatial_len = max_spatial_len
            # spatial memory
            num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
            self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
            self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)

        # learnable positive negative indicator
        self.pn_indicator = nn.Embedding(2, hidden_dim)

        # build AttentionDataStruct
        attn_arch['NUM_LAYERS'] = self.num_layers
        self.attention_data = AttentionDataStruct(attn_arch, task_switch)
        self.sample_size = attn_arch['QUERY_NUMBER']

    @classmethod
    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
        ret = {}

        ret["lang_encoder"] = lang_encoder
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
        ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
        ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
        ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        # Transformer parameters:
        ret["nheads"] = dec_cfg['NHEADS']
        ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert dec_cfg['DEC_LAYERS'] >= 1
        ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
        ret["pre_norm"] = dec_cfg['PRE_NORM']
        ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
        ret["mask_dim"] = enc_cfg['MASK_DIM']
        ret["task_switch"] = extra['task_switch']
        ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']

        # attn data struct
        ret["attn_arch"] = cfg['ATTENTION_ARCH']

        return ret

    def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels; del mask
        spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
        grounding_extra_flag = 'grounding_tokens' in extra.keys()
        spatial_memory_flag = 'prev_mask' in extra.keys()
        flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
        self.attention_data.reset(flags, task, extra)

        src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
        _,bs,_ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        self.attention_data.set('queries_object', 'queries', output, query_embed)

        if self.task_switch['spatial'] and spatial_extra_flag:
            if 'refimg_tokens' not in extra:
                # get divisor
                c,h,w = extra['spatial_query_pos_mask'][0].shape
                divisor = torch.tensor([1,h,w], device=output.device)[None,]

                # Get mean pos spatial query
                non_zero_pos_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
                non_zero_pos_index = [m[:,0:1].long() for m in non_zero_pos_point]
                non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
                non_zero_pos_index = nn.utils.rnn.pad_sequence(non_zero_pos_index, padding_value=-1).permute(1,0,2)[:,:,0]
                non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
                spatial_query_pos = point_sample(mask_features, non_zero_pos_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                num_mask_per_batch = [len(m) for m in extra['spatial_query_pos_mask']]
                spatial_query_pos = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_pos.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask, non_zero_pos_index, num_mask_per_batch)], padding_value=-1).nan_to_num()

                # Get mean neg spatial query
                non_zero_neg_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
                non_zero_neg_index = [m[:,0:1].long() for m in non_zero_neg_point]
                non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
                non_zero_neg_index = nn.utils.rnn.pad_sequence(non_zero_neg_index, padding_value=-1).permute(1,0,2)[:,:,0]
                non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
                spatial_query_neg = point_sample(mask_features, non_zero_neg_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
                num_mask_per_batch = [len(m) for m in extra['spatial_query_neg_mask']]
                spatial_query_neg = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_neg.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask, non_zero_neg_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
                # Get layerwise spatial query
                src_spatial_queries = []
                src_spatial_maskings = []
                src_spatial_indices = []
                for i in range(len(src)):
                    hw,_,dc = src[i].shape
                    src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
                    src_mask_features = src_mask_features @ self.mask_sptial_embed[i]

                    non_zero_query_point_pos = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
                    non_zero_query_point_neg = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
                    non_zero_query_point = [torch.cat([x[:,1:],y[:,1:]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                    non_zero_query_index = [torch.cat([x[:,0:1],y[:,0:1]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]

                    pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                    pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)

                    non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
                    non_zero_query_index = nn.utils.rnn.pad_sequence(non_zero_query_index, padding_value=-1).permute(1,0,2)
                    non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
                    non_zero_query_point[non_zero_query_mask] = 0

                    spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
                    spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
                    spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]

                    src_spatial_queries += [spatial_tokens]
                    src_spatial_maskings += [non_zero_query_mask]
                    src_spatial_indices += [non_zero_query_index]

                if 'refimg' in task:
                    output_refimg = {}
                    output_refimg['spatial_query_pos'] = spatial_query_pos
                    output_refimg['spatial_query_neg'] = spatial_query_neg
                    output_refimg['src_spatial_queries'] = src_spatial_queries
                    output_refimg['src_spatial_maskings'] = src_spatial_maskings
                    return output_refimg
            else:
                spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
                spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
                src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
                src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']

            # Get object query for spatial index
            self.attention_data.set_extra({"spatial_query_number": len(spatial_query_pos), "sample_size": self.sample_size})
            self.attention_data.set('queries_spatial', 'queries', sample_size=self.sample_size*len(spatial_query_pos))

            # set spatial memory
            spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
            spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
            self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)

        if self.task_switch['grounding'] and grounding_extra_flag:
            # Get grounding tokens
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()

            self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
            self.attention_data.set('queries_grounding', 'queries')
            self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])

        output, query_embed = self.attention_data.cross_attn_variables()
        # prediction heads on learnable query features
        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
        results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
        results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
        self.attention_data.set_results(results)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # CROSS ATTENTION
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )
            self.attention_data.update_variables(output, 'cross_attn')

            # SELF ATTENTION
            self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
            if self.task_switch['spatial'] and spatial_extra_flag:
                # get spatial tokens
                spatial_tokens = src_spatial_queries[level_index]
                _spatial_tokens = spatial_tokens.detach().clone()

                self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
                self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
                self.attention_data.set_extra({"spatial_indices": src_spatial_indices[level_index]})

            output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=self_attn_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed)

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            self.attention_data.update_variables(output, 'self_attn')
            output, query_embed = self.attention_data.cross_attn_variables()
            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
            results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
            results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
            self.attention_data.set_results(results)

        return self.attention_data.organize_output()

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        class_embed = decoder_output @ self.class_embed
        outputs_class = self.lang_encoder.compute_similarity(class_embed)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        outputs_bbox = [None for i in range(len(outputs_mask))]
        if self.task_switch['bbox']:
            outputs_bbox = self.bbox_embed(decoder_output)

        # NOTE: prediction is of higher-resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)

        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()

        outputs_caption = class_embed

        results = {
            "attn_mask": attn_mask,
            "predictions_class": outputs_class,
            "predictions_mask": outputs_mask,
            "predictions_bbox": outputs_bbox,
            "predictions_caption": outputs_caption,
            "predictions_maskemb": mask_embed,
        }
        return results


@register_decoder
def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
    return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
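Illustrative sketch (not part of the commit): the spatial-query construction above normalizes the nonzero pixels of a click/scribble mask into [0, 1] coordinates, flips them to (x, y) order, samples mask features there, and mean-pools them into a query. The helper below mimics the repo's point_sample with grid_sample so the example stays self-contained; names and shapes are assumptions for the demo.

```python
import torch
import torch.nn.functional as F

def point_sample(input, point_coords, **kwargs):
    # point_coords: [B, P, 2] in [0, 1], ordered (x, y)
    out = F.grid_sample(input, 2.0 * point_coords.unsqueeze(2) - 1.0, **kwargs)
    return out.squeeze(3)                                      # [B, C, P]

feats = torch.randn(1, 256, 32, 32)                            # [B, C, H, W] mask features
click = torch.zeros(1, 64, 64); click[0, 10, 20] = 1           # binary spatial query mask
h, w = click.shape[-2:]
pts = click.nonzero()[:, 1:].float() / torch.tensor([h, w])    # normalized (y, x) coordinates
pts = pts.flip(dims=(1,))[None]                                # -> (x, y), shape [1, P, 2]
query = point_sample(feats, pts, align_corners=True).mean(dim=-1)  # one pooled spatial query
```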
modeling/interface/xdecoder.py
ADDED
@@ -0,0 +1,497 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import logging
from typing import Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from timm.models.layers import trunc_normal_
from detectron2.layers import Conv2d
import fvcore.nn.weight_init as weight_init

from .build import register_decoder
from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from ..utils import configurable
from ..modules import PositionEmbeddingSine


class XDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        lang_encoder: nn.Module,
        in_channels,
        mask_classification=True,
        *,
        hidden_dim: int,
        dim_proj: int,
        num_queries: int,
        contxt_len: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        task_switch: dict,
        captioning_step: int,
        enforce_input_project: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
        """
        super().__init__()
        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.contxt_len = contxt_len
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()

        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        self.task_switch = task_switch

        # output FFNs
        self.lang_encoder = lang_encoder
        if self.task_switch['mask']:
            self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
        trunc_normal_(self.class_embed, std=.02)

        if task_switch['bbox']:
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # Caption Project and query
        if task_switch['captioning']:
            self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
            trunc_normal_(self.caping_embed, std=.02)
            self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
            self.captioning_step = captioning_step

        # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query
        self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
        self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query.
        self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token.
        self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query.
        self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query.
        self.register_buffer("self_attn_mask", self_attn_mask)


    @classmethod
    def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
        ret = {}

        ret["lang_encoder"] = lang_encoder
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        enc_cfg = cfg['MODEL']['ENCODER']
        dec_cfg = cfg['MODEL']['DECODER']

        ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
        ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
        ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
        ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        # Transformer parameters:
        ret["nheads"] = dec_cfg['NHEADS']
        ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert dec_cfg['DEC_LAYERS'] >= 1
        ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
        ret["pre_norm"] = dec_cfg['PRE_NORM']
        ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
        ret["mask_dim"] = enc_cfg['MASK_DIM']

        ret["task_switch"] = extra['task_switch']
        ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)

        return ret

    def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
        if task == 'captioning_infer':
            return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_class = []
        predictions_mask = []
        predictions_bbox = []
        predictions_caption = []
        predictions_captioning = []

        self_tgt_mask = None
        if self.training and task == 'vlp' and self.task_switch['captioning']:
            # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
            caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
            _caping_lang_embed = caping_lang_embed.detach().clone()
            output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
            caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
            query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
            self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
        elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
            self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
            grounding_tokens = extra['grounding_tokens']
            _grounding_tokens = grounding_tokens.detach().clone()
            # initialize with negative attention at the beginning.
            pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
            pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
            pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens could attend with eatch other
            self_tgt_mask = pad_tgt_mask
            output = torch.cat((output, output[:-1]), dim=0)
            query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad language embdding to fix embedding
        else:
            self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)

        # prediction heads on learnable query features
        results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
        attn_mask = results["attn_mask"]
        predictions_class.append(results["outputs_class"])
        predictions_mask.append(results["outputs_mask"])
        predictions_bbox.append(results["outputs_bbox"])
        predictions_caption.append(results["outputs_caption"])
        predictions_captioning.append(results["outputs_captionting"])

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False

            if self.training and task == 'vlp' and self.task_switch['captioning']:
                attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
            # attention: cross-attention first
            output, avg_attn = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )

            if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
                output = torch.cat((output, _grounding_tokens), dim=0)
                query_embed = torch.cat((query_embed, grounding_tokens), dim=0)

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=self_tgt_mask,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']:
                _grounding_tokens = output[-len(_grounding_tokens):]
                output = output[:-len(_grounding_tokens)]
                query_embed = query_embed[:-len(_grounding_tokens)]

            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
            attn_mask = results["attn_mask"]
            predictions_class.append(results["outputs_class"])
            predictions_mask.append(results["outputs_mask"])
            predictions_bbox.append(results["outputs_bbox"])
            predictions_caption.append(results["outputs_caption"])
            predictions_captioning.append(results["outputs_captionting"])

        assert len(predictions_class) == self.num_layers + 1
        if task == 'vlp':
            out = {'pred_captionings': predictions_captioning[-1],
                   'pred_captions': predictions_caption[-1],
                   'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
            return out
        else:
            out = {
                'pred_logits': predictions_class[-1],
                'pred_masks': predictions_mask[-1],
                'pred_boxes': predictions_bbox[-1],
                'pred_captions': predictions_caption[-1],
                'aux_outputs': self._set_aux_loss(
                    predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
                )
            }
            return out

    def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
        caping_lang_token = extra['start_token'].repeat(bs, 1)
        pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)

        # prepare token embedding for evaluation
        token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
        # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)

        for cap_idx in range(0, self.captioning_step):
            caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
            output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
            caping_lang_embed += pos_embed_caping
            query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
            # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.

            # prediction heads on learnable query features
            results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
            attn_mask = results["attn_mask"]

            for i in range(self.num_layers):
                level_index = i % self.num_feature_levels
                attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
                attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
                self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)

                if extra['captioning_mask'] is not None:
                    bs,nq,wh = attn_mask.shape
                    assert bs==self.num_heads, "Only support single image referring captioning."
                    cap_mask = extra['captioning_mask']
                    attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
                    cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
                    attn_mask[:,self.num_queries:, cap_mask] = True
                    attn_mask = attn_mask.reshape(bs,nq,wh)

                # attention: cross-attention first
                output, avg_attn = self.transformer_cross_attention_layers[i](
|
375 |
+
output, src[level_index],
|
376 |
+
memory_mask=attn_mask,
|
377 |
+
memory_key_padding_mask=None, # here we do not apply masking on padded region
|
378 |
+
pos=pos[level_index], query_pos=query_embed
|
379 |
+
)
|
380 |
+
|
381 |
+
output = self.transformer_self_attention_layers[i](
|
382 |
+
output, tgt_mask=self_tgt_mask,
|
383 |
+
tgt_key_padding_mask=None,
|
384 |
+
query_pos=query_embed
|
385 |
+
)
|
386 |
+
|
387 |
+
# FFN
|
388 |
+
output = self.transformer_ffn_layers[i](
|
389 |
+
output
|
390 |
+
)
|
391 |
+
|
392 |
+
results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
|
393 |
+
attn_mask = results["attn_mask"]
|
394 |
+
|
395 |
+
pred_captions_gen = results['outputs_captionting']
|
396 |
+
# pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
|
397 |
+
pred_captions_gen = pred_captions_gen @ token_embs.t()
|
398 |
+
caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
|
399 |
+
|
400 |
+
texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
|
401 |
+
texts_new = []
|
402 |
+
|
403 |
+
for x in texts:
|
404 |
+
x = x.split('<|endoftext|>')[0]
|
405 |
+
x = x.replace('<|endoftext|>','')
|
406 |
+
x = x.replace('<|startoftext|>','')
|
407 |
+
x = x.strip()
|
408 |
+
texts_new.append(x)
|
409 |
+
|
410 |
+
out = {'pred_captionings': caping_lang_token,
|
411 |
+
'pred_texts': texts_new}
|
412 |
+
return out
|
413 |
+
|
414 |
+
|
415 |
+
def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
|
416 |
+
decoder_output = self.decoder_norm(output)
|
417 |
+
decoder_output = decoder_output.transpose(0, 1)
|
418 |
+
|
419 |
+
# extract image captioning token from decoder output.
|
420 |
+
if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
|
421 |
+
outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
|
422 |
+
else:
|
423 |
+
outputs_captionting = None
|
424 |
+
|
425 |
+
# recompute class token output.
|
426 |
+
norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
|
427 |
+
obj_token = norm_decoder_output[:,:self.num_queries-1]
|
428 |
+
cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
|
429 |
+
|
430 |
+
sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
|
431 |
+
cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
|
432 |
+
|
433 |
+
if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
|
434 |
+
decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
|
435 |
+
else:
|
436 |
+
decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
|
437 |
+
|
438 |
+
# compute class, mask and bbox.
|
439 |
+
class_embed = decoder_output @ self.class_embed
|
440 |
+
# HACK do not compute similarity if mask is not on
|
441 |
+
outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training)))
|
442 |
+
|
443 |
+
if self.task_switch['mask']:
|
444 |
+
mask_embed = self.mask_embed(decoder_output)
|
445 |
+
outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
|
446 |
+
|
447 |
+
# NOTE: prediction is of higher-resolution
|
448 |
+
# [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
|
449 |
+
attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bicubic", align_corners=False, antialias=True)
|
450 |
+
|
451 |
+
# must use bool type
|
452 |
+
# If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
|
453 |
+
attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
|
454 |
+
attn_mask = attn_mask.detach()
|
455 |
+
|
456 |
+
# NOTE: fill False for cls token (JY)
|
457 |
+
attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
|
458 |
+
else:
|
459 |
+
outputs_mask = None
|
460 |
+
attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
|
461 |
+
|
462 |
+
outputs_bbox = [None for i in range(len(decoder_output))]
|
463 |
+
if self.task_switch['bbox']:
|
464 |
+
outputs_bbox = self.bbox_embed(decoder_output)
|
465 |
+
|
466 |
+
outputs_caption = None
|
467 |
+
if self.task_switch['caption']:
|
468 |
+
outputs_caption = class_embed
|
469 |
+
|
470 |
+
|
471 |
+
results = {
|
472 |
+
"outputs_class": outputs_class,
|
473 |
+
"outputs_mask": outputs_mask,
|
474 |
+
"outputs_bbox": outputs_bbox,
|
475 |
+
"attn_mask": attn_mask,
|
476 |
+
"outputs_caption": outputs_caption,
|
477 |
+
"outputs_captionting": outputs_captionting,
|
478 |
+
}
|
479 |
+
return results
|
480 |
+
|
481 |
+
@torch.jit.unused
|
482 |
+
def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
|
483 |
+
# this is a workaround to make torchscript happy, as torchscript
|
484 |
+
# doesn't support dictionary with non-homogeneous values, such
|
485 |
+
# as a dict having both a Tensor and a list.
|
486 |
+
if self.mask_classification:
|
487 |
+
return [
|
488 |
+
{"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
|
489 |
+
for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
|
490 |
+
]
|
491 |
+
else:
|
492 |
+
return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
|
493 |
+
|
494 |
+
|
495 |
+
@register_decoder
|
496 |
+
def get_xdecoder_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
|
497 |
+
return XDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
|
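One detail in the decoder layer loop above is easy to miss: before every cross-attention step, any query whose predicted mask hides all memory positions is un-masked again, otherwise that query would attend to nothing and produce degenerate attention weights. Below is a minimal standalone sketch of that guard on a toy boolean mask; the tensor shapes and values are illustrative only, not taken from the file.

import torch

# Toy boolean attention mask: True = position is masked out.
# Shape [batch*heads, num_queries, H*W], as produced by forward_prediction_heads.
attn_mask = torch.ones(2, 3, 4, dtype=torch.bool)
attn_mask[0, 0, 1] = False          # query 0 keeps one visible position

# Rows that are entirely True would attend to nothing; reset them to fully visible,
# mirroring `attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False`.
fully_masked = attn_mask.sum(-1) == attn_mask.shape[-1]
attn_mask[torch.where(fully_masked)] = False
print(fully_masked)                  # which (batch*head, query) rows were rescued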
modeling/language/LangEncoder/__init__.py
ADDED
@@ -0,0 +1,35 @@
import os  # needed for the TOKENIZERS_PARALLELISM toggle below

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers import AutoTokenizer

from .transformer import *
from .build import *


def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
    model_name = config_encoder['NAME']

    if not is_lang_encoder(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)

def build_tokenizer(config_encoder):
    tokenizer = None
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    if config_encoder['TOKENIZER'] == 'clip':
        pretrained_tokenizer = config_encoder.get(
            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
        )
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
        tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
    elif config_encoder['TOKENIZER'] == 'clip-fast':
        pretrained_tokenizer = config_encoder.get(
            'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
        )
        tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
    elif config_encoder['TOKENIZER'] == 'biomed-clip':
        tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
    else:
        tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])

    return tokenizer
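For orientation, here is a minimal sketch of calling the tokenizer factory above with the biomed-clip branch. The hand-written config dict and the max_length value are illustrative assumptions; the actual values used by this Space come from configs/biomedparse_inference.yaml, and fetching the tokenizer requires network access to the Hugging Face Hub (plus the utilities package from this commit being importable).

from modeling.language.LangEncoder import build_tokenizer

text_cfg = {'TOKENIZER': 'biomed-clip'}            # selects the BiomedBERT tokenizer branch
tokenizer = build_tokenizer(text_cfg)
tokens = tokenizer('neoplastic cells in breast pathology',
                   padding='max_length', truncation=True, max_length=77,
                   return_tensors='pt')
print(tokens['input_ids'].shape)                    # torch.Size([1, 77])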
modeling/language/LangEncoder/build.py
ADDED
@@ -0,0 +1,16 @@
_lang_encoders = {}


def register_lang_encoder(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]

    _lang_encoders[model_name] = fn

    return fn

def lang_encoders(model_name):
    return _lang_encoders[model_name]

def is_lang_encoder(model_name):
    return model_name in _lang_encoders
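The registry above keys each encoder by the module name of the function being registered, so the NAME field in the text config has to match the defining file name (e.g. 'transformer' for the encoder in transformer.py). A toy illustration of the same decorator pattern, deliberately independent of the real modules (names here are made up for the sketch):

_toy_registry = {}

def register_toy(fn):
    # keyed by the defining module's last path component, e.g. 'transformer'
    _toy_registry[fn.__module__.split('.')[-1]] = fn
    return fn

@register_toy
def dummy_encoder(config_encoder, tokenizer, verbose, **kwargs):
    return object()   # stand-in for a real nn.Module factory

# In a standalone script __module__ is '__main__', so that is the lookup key here.
print('__main__' in _toy_registry)   # True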
modeling/language/LangEncoder/transformer.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from timm.models.layers import DropPath, trunc_normal_
|
12 |
+
|
13 |
+
from .build import register_lang_encoder
|
14 |
+
from utilities.distributed import is_main_process
|
15 |
+
from utilities.model import register_norm_module
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
@register_norm_module
|
21 |
+
class LayerNorm(nn.Module):
|
22 |
+
def __init__(self, hidden_size, eps=1e-12):
|
23 |
+
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
24 |
+
"""
|
25 |
+
super(LayerNorm, self).__init__()
|
26 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
27 |
+
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
28 |
+
self.variance_epsilon = eps
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
pdtype = x.dtype
|
32 |
+
x = x.float()
|
33 |
+
u = x.mean(-1, keepdim=True)
|
34 |
+
s = (x - u).pow(2).mean(-1, keepdim=True)
|
35 |
+
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
36 |
+
return self.weight * x.to(pdtype) + self.bias
|
37 |
+
|
38 |
+
|
39 |
+
class QuickGELU(nn.Module):
|
40 |
+
def forward(self, x: torch.Tensor):
|
41 |
+
return x * torch.sigmoid(1.702 * x)
|
42 |
+
|
43 |
+
|
44 |
+
class ResidualAttentionBlock(nn.Module):
|
45 |
+
def __init__(self,
|
46 |
+
d_model: int,
|
47 |
+
n_head: int,
|
48 |
+
attn_mask: torch.Tensor = None,
|
49 |
+
drop_path: float = 0.0):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
53 |
+
self.ln_1 = LayerNorm(d_model)
|
54 |
+
self.mlp = nn.Sequential(OrderedDict([
|
55 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
56 |
+
("gelu", QuickGELU()),
|
57 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
58 |
+
]))
|
59 |
+
self.ln_2 = LayerNorm(d_model)
|
60 |
+
self.attn_mask = attn_mask
|
61 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
62 |
+
|
63 |
+
def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
|
64 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
|
65 |
+
if self.attn_mask is not None else None
|
66 |
+
|
67 |
+
|
68 |
+
return self.attn(
|
69 |
+
x, x, x,
|
70 |
+
key_padding_mask=key_padding_mask,
|
71 |
+
need_weights=False,
|
72 |
+
attn_mask=self.attn_mask
|
73 |
+
)[0]
|
74 |
+
|
75 |
+
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
|
76 |
+
x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
|
77 |
+
x = x + self.drop_path(self.mlp(self.ln_2(x)))
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class Transformer(nn.Module):
|
82 |
+
def __init__(self,
|
83 |
+
context_length: int,
|
84 |
+
vocab_size: int,
|
85 |
+
width: int,
|
86 |
+
layers: int,
|
87 |
+
heads: int,
|
88 |
+
drop_path: float = 0.0,
|
89 |
+
autogressive: bool =True):
|
90 |
+
super().__init__()
|
91 |
+
|
92 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
93 |
+
|
94 |
+
self.context_length = context_length
|
95 |
+
self.positional_embedding = nn.Parameter(
|
96 |
+
torch.empty(self.context_length, width)
|
97 |
+
)
|
98 |
+
|
99 |
+
self.width = width
|
100 |
+
self.layers = layers
|
101 |
+
self.autogressive = autogressive
|
102 |
+
attn_mask = self.build_attention_mask() if autogressive else None
|
103 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
|
104 |
+
self.resblocks = nn.ModuleList(
|
105 |
+
[
|
106 |
+
ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
|
107 |
+
for i in range(layers)
|
108 |
+
]
|
109 |
+
)
|
110 |
+
|
111 |
+
self.ln_final = LayerNorm(width)
|
112 |
+
|
113 |
+
trunc_normal_(self.positional_embedding, std=.02)
|
114 |
+
# nn.init.normal_(self.token_embedding, std=.02)
|
115 |
+
trunc_normal_(self.token_embedding.weight, std=.02)
|
116 |
+
self.apply(self._init_weights)
|
117 |
+
|
118 |
+
@property
|
119 |
+
def dim_out(self):
|
120 |
+
return self.width
|
121 |
+
|
122 |
+
def build_attention_mask(self):
|
123 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
124 |
+
# pytorch uses additive attention mask; fill with -inf
|
125 |
+
mask = torch.empty(self.context_length, self.context_length)
|
126 |
+
mask.fill_(float("-inf"))
|
127 |
+
mask.triu_(1) # zero out the lower diagonal
|
128 |
+
return mask
|
129 |
+
|
130 |
+
def _init_weights(self, m):
|
131 |
+
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
132 |
+
if is_main_process():
|
133 |
+
logger.info('=> init weight of Linear/Conv2d from trunc norm')
|
134 |
+
trunc_normal_(m.weight, std=0.02)
|
135 |
+
if m.bias is not None:
|
136 |
+
if is_main_process():
|
137 |
+
logger.info('=> init bias of Linear/Conv2d to zeros')
|
138 |
+
nn.init.constant_(m.bias, 0)
|
139 |
+
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
140 |
+
nn.init.constant_(m.bias, 0)
|
141 |
+
|
142 |
+
def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
|
143 |
+
if os.path.isfile(pretrained):
|
144 |
+
pretrained_dict = torch.load(pretrained, map_location='cpu')
|
145 |
+
logging.info(f'=> loading pretrained model {pretrained}')
|
146 |
+
model_dict = self.state_dict()
|
147 |
+
stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
|
148 |
+
pretrained_dict = {
|
149 |
+
stripped_key(k): v for k, v in pretrained_dict.items()
|
150 |
+
if stripped_key(k) in model_dict.keys()
|
151 |
+
}
|
152 |
+
need_init_state_dict = {}
|
153 |
+
for k, v in pretrained_dict.items():
|
154 |
+
need_init = (
|
155 |
+
k.split('.')[0] in pretrained_layers
|
156 |
+
or pretrained_layers[0] == '*'
|
157 |
+
)
|
158 |
+
if need_init:
|
159 |
+
if verbose:
|
160 |
+
logger.info(f'=> init {k} from {pretrained}')
|
161 |
+
|
162 |
+
if 'positional_embedding' in k and v.size() != model_dict[k].size():
|
163 |
+
positional_embedding_pretrained = v
|
164 |
+
positional_embedding_current = model_dict[k]
|
165 |
+
L1, nH1 = positional_embedding_pretrained.size()
|
166 |
+
L2, nH2 = positional_embedding_current.size()
|
167 |
+
if nH1 != nH2:
|
168 |
+
logger.info(f"Error in loading {k}, passing")
|
169 |
+
else:
|
170 |
+
if L1 != L2:
|
171 |
+
logger.info(
|
172 |
+
'=> load_pretrained: resized variant: {} to {}'
|
173 |
+
.format((L1, nH1), (L2, nH2))
|
174 |
+
)
|
175 |
+
|
176 |
+
posemb = positional_embedding_pretrained.float()
|
177 |
+
posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
|
178 |
+
posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
|
179 |
+
posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
|
180 |
+
v = posemb_grid
|
181 |
+
|
182 |
+
need_init_state_dict[k] = v
|
183 |
+
|
184 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
185 |
+
|
186 |
+
|
187 |
+
@torch.jit.ignore
|
188 |
+
def no_weight_decay(self):
|
189 |
+
return {
|
190 |
+
'positional_embedding',
|
191 |
+
'token_embedding',
|
192 |
+
}
|
193 |
+
|
194 |
+
def forward(self, input_ids, attention_mask=None):
|
195 |
+
key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
|
196 |
+
# key_padding_mask = (input_ids == 0) if not self.autogressive else None
|
197 |
+
x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
|
198 |
+
x = x + self.positional_embedding
|
199 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
200 |
+
for block in self.resblocks:
|
201 |
+
x = block(x, key_padding_mask)
|
202 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
203 |
+
|
204 |
+
x = self.ln_final(x)
|
205 |
+
|
206 |
+
return {'last_hidden_state': x}
|
207 |
+
|
208 |
+
|
209 |
+
@register_lang_encoder
|
210 |
+
def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
|
211 |
+
transformer = Transformer(
|
212 |
+
context_length=config_encoder['CONTEXT_LENGTH'],
|
213 |
+
vocab_size=tokenizer.vocab_size,
|
214 |
+
width=config_encoder['WIDTH'],
|
215 |
+
layers=config_encoder['LAYERS'],
|
216 |
+
heads=config_encoder['HEADS'],
|
217 |
+
autogressive=config_encoder.get('AUTOGRESSIVE', True)
|
218 |
+
)
|
219 |
+
|
220 |
+
if config_encoder.get('LOAD_PRETRAINED', False):
|
221 |
+
transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
|
222 |
+
return transformer
|
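The causal mask produced by build_attention_mask in the transformer above is additive: -inf strictly above the diagonal and 0 elsewhere, so each token may attend only to itself and earlier positions. A small standalone check of that construction (context_length of 4 chosen purely for illustration):

import torch

context_length = 4
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)            # keep -inf strictly above the diagonal, zero the rest
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])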
modeling/language/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .vlpencoder import *
from .build import *

def build_language_encoder(config, **kwargs):
    model_name = config['MODEL']['TEXT']['ARCH']

    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)
modeling/language/build.py
ADDED
@@ -0,0 +1,14 @@
_model_entrypoints = {}


def register_model(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]
    _model_entrypoints[model_name] = fn
    return fn

def model_entrypoints(model_name):
    return _model_entrypoints[model_name]

def is_model(model_name):
    return model_name in _model_entrypoints
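As with the LangEncoder registry, register_model stores each factory under its module name, so MODEL.TEXT.ARCH must be 'vlpencoder' for build_language_encoder to resolve get_language_model. A hedged sketch of that lookup, assuming the repository root is on PYTHONPATH so the package imports succeed:

# Importing the package runs `from .vlpencoder import *`, which triggers the
# @register_model decorator as a side effect.
import modeling.language  # noqa: F401
from modeling.language.build import is_model, model_entrypoints

print(is_model('vlpencoder'))                      # True
print(model_entrypoints('vlpencoder').__name__)    # 'get_language_model'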
modeling/language/loss.py
ADDED
@@ -0,0 +1,232 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Xueyan Zou ([email protected])
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import pickle
|
9 |
+
from distutils import log
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
from timm.loss import SoftTargetCrossEntropy
|
17 |
+
|
18 |
+
soft_cross_entropy = SoftTargetCrossEntropy()
|
19 |
+
|
20 |
+
def is_dist_initialized():
|
21 |
+
return torch.distributed.is_initialized()
|
22 |
+
|
23 |
+
def get_world_size():
|
24 |
+
if is_dist_initialized():
|
25 |
+
return torch.distributed.get_world_size()
|
26 |
+
return 1
|
27 |
+
|
28 |
+
def get_rank():
|
29 |
+
if is_dist_initialized():
|
30 |
+
return dist.get_rank()
|
31 |
+
return 0
|
32 |
+
|
33 |
+
def all_gather_grad(x):
|
34 |
+
if get_world_size() > 1:
|
35 |
+
all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
|
36 |
+
torch.distributed.all_gather(all_x, x)
|
37 |
+
all_x[torch.distributed.get_rank()] = x
|
38 |
+
x = torch.cat(all_x, dim=0)
|
39 |
+
return x
|
40 |
+
|
41 |
+
def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
|
45 |
+
text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
"""
|
49 |
+
# [B, L1, C], L1 = 1
|
50 |
+
# image_feat = F.normalize(image_feat, dim=-1)
|
51 |
+
# [B, L2, C]
|
52 |
+
# text_feat = F.normalize(text_feat, dim=-1)
|
53 |
+
# HACK: normalize outside
|
54 |
+
|
55 |
+
# [B, L1, L2]
|
56 |
+
dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
|
57 |
+
# [B, L2, L1]
|
58 |
+
dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
|
59 |
+
|
60 |
+
batch = image_feat.shape[0]
|
61 |
+
img_len = image_feat.shape[1]
|
62 |
+
text_len = text_feat.shape[1]
|
63 |
+
# [B, L1, L2]
|
64 |
+
pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
|
65 |
+
# [B, L2, L1]
|
66 |
+
pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
|
67 |
+
|
68 |
+
image_x = rearrange(image_feat, 'b l c -> (b l) c')
|
69 |
+
text_x = rearrange(text_feat, 'b l c -> (b l) c')
|
70 |
+
|
71 |
+
logits_per_img = image_x @ all_gather_grad(text_x).t()
|
72 |
+
logits_per_text = text_x @ all_gather_grad(image_x).t()
|
73 |
+
|
74 |
+
# get label globally
|
75 |
+
# [B, L1, B, L2, W]
|
76 |
+
labels_per_img = F.one_hot(
|
77 |
+
torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
|
78 |
+
num_classes=get_world_size()).to(image_x.dtype)
|
79 |
+
labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
|
80 |
+
torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
|
81 |
+
# [BxL1, WxBxL2]
|
82 |
+
labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
|
83 |
+
# [B, L2, B, L1, W]
|
84 |
+
labels_per_text = F.one_hot(
|
85 |
+
torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
|
86 |
+
num_classes=get_world_size()).to(text_x.dtype)
|
87 |
+
labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
|
88 |
+
torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
|
89 |
+
# [BxL2, WxBxL1]
|
90 |
+
labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
|
91 |
+
|
92 |
+
logit_scale = temperature.exp().clamp(max=100)
|
93 |
+
|
94 |
+
loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
|
95 |
+
loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
|
96 |
+
|
97 |
+
loss = 0.5 * (loss_img + loss_text)
|
98 |
+
return loss
|
99 |
+
|
100 |
+
def vl_contrastive_loss(image_feat, text_feat, temperature=1):
|
101 |
+
# if image_id or text_id is None, it should be None across all GPUs
|
102 |
+
# image_feat = F.normalize(image_feat, dim=1)
|
103 |
+
# text_feat = F.normalize(text_feat, dim=1)
|
104 |
+
# handle normalization outside
|
105 |
+
|
106 |
+
# add the following 4 lines
|
107 |
+
image_feat = all_gather_grad(image_feat)
|
108 |
+
text_feat = all_gather_grad(text_feat)
|
109 |
+
|
110 |
+
logits = torch.matmul(image_feat, text_feat.t())
|
111 |
+
logit_scale = temperature.exp().clamp(max=100)
|
112 |
+
|
113 |
+
gt = torch.arange(logits.shape[0], device=logits.device)
|
114 |
+
loss1 = F.cross_entropy(logit_scale * logits, gt)
|
115 |
+
loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
|
116 |
+
return (loss1 + loss2) / 2 # scale it up by the number of GPUs
|
117 |
+
|
118 |
+
|
119 |
+
def all_gather_pickle(data, device):
|
120 |
+
"""
|
121 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
122 |
+
Args:
|
123 |
+
data: any picklable object
|
124 |
+
Returns:
|
125 |
+
list[data]: list of data gathered from each rank
|
126 |
+
"""
|
127 |
+
world_size = get_world_size()
|
128 |
+
if world_size == 1:
|
129 |
+
return [data]
|
130 |
+
|
131 |
+
# serialized to a Tensor
|
132 |
+
buffer = pickle.dumps(data)
|
133 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
134 |
+
tensor = torch.ByteTensor(storage).to(device)
|
135 |
+
|
136 |
+
# obtain Tensor size of each rank
|
137 |
+
local_size = torch.LongTensor([tensor.numel()]).cuda()
|
138 |
+
size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
|
139 |
+
dist.all_gather(size_list, local_size)
|
140 |
+
size_list = [int(size.item()) for size in size_list]
|
141 |
+
max_size = max(size_list)
|
142 |
+
|
143 |
+
# receiving Tensor from all ranks
|
144 |
+
# we pad the tensor because torch all_gather does not support
|
145 |
+
# gathering tensors of different shapes
|
146 |
+
tensor_list = []
|
147 |
+
for _ in size_list:
|
148 |
+
tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())
|
149 |
+
if local_size != max_size:
|
150 |
+
padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
|
151 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
152 |
+
dist.all_gather(tensor_list, tensor)
|
153 |
+
|
154 |
+
data_list = []
|
155 |
+
for size, tensor in zip(size_list, tensor_list):
|
156 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
157 |
+
data_list.append(pickle.loads(buffer))
|
158 |
+
|
159 |
+
return data_list
|
160 |
+
|
161 |
+
def all_gather_arbitary_tensor(tensor):
|
162 |
+
if get_world_size() > 1:
|
163 |
+
device = tensor.device
|
164 |
+
tensor_batch = all_gather_pickle(tensor.cpu(), device)
|
165 |
+
tensor_batch = [x.to(device) for x in tensor_batch]
|
166 |
+
tensor_batch[torch.distributed.get_rank()] = tensor
|
167 |
+
tensor_batch = torch.cat(tensor_batch, dim=0)
|
168 |
+
else:
|
169 |
+
tensor_batch = tensor
|
170 |
+
return tensor_batch
|
171 |
+
|
172 |
+
def ql_contrastive_loss(image_feat, text_feat, temperature=1):
|
173 |
+
# add the following 4 lines
|
174 |
+
image_feat = all_gather_arbitary_tensor(image_feat)
|
175 |
+
text_feat = all_gather_arbitary_tensor(text_feat)
|
176 |
+
|
177 |
+
logits = torch.matmul(image_feat, text_feat.t())
|
178 |
+
logit_scale = temperature.exp().clamp(max=100)
|
179 |
+
|
180 |
+
gt = torch.arange(logits.shape[0], device=logits.device)
|
181 |
+
loss1 = F.cross_entropy(logit_scale * logits, gt)
|
182 |
+
loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
|
183 |
+
return (loss1 + loss2) / 2 # scale it up by the number of GPUs
|
184 |
+
|
185 |
+
def vl_similarity(image_feat, text_feat, temperature=1):
|
186 |
+
# Only support single GPU for now.
|
187 |
+
logits = torch.matmul(image_feat, text_feat.t())
|
188 |
+
logits = temperature.exp().clamp(max=100) * logits
|
189 |
+
return logits
|
190 |
+
|
191 |
+
def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
|
192 |
+
# add the following 4 lines
|
193 |
+
image_feat = all_gather_arbitary_tensor(image_feat)
|
194 |
+
text_feat = all_gather_arbitary_tensor(text_feat)
|
195 |
+
|
196 |
+
text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
|
197 |
+
text_hash_all = torch.cat(text_hash_batch)
|
198 |
+
|
199 |
+
text_hash_all_unique = torch.unique(text_hash_all).tolist()
|
200 |
+
gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
|
201 |
+
text_hash_all = text_hash_all.tolist()
|
202 |
+
text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
|
203 |
+
|
204 |
+
for idx, txt in enumerate(text_hash_all):
|
205 |
+
gt[idx][text_hash_all_unique.index(txt)] = 1
|
206 |
+
|
207 |
+
logits = torch.matmul(image_feat, text_feat_unique.t())
|
208 |
+
logits = logits*temperature.exp().clamp(max=100)
|
209 |
+
|
210 |
+
loss_img = soft_cross_entropy(logits, gt)
|
211 |
+
loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
|
212 |
+
|
213 |
+
loss = 0.7 * loss_img + 0.3 * loss_text
|
214 |
+
return loss
|
215 |
+
|
216 |
+
def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
|
217 |
+
# add the following 4 lines
|
218 |
+
image_feat = all_gather_grad(image_feat_inp.contiguous())
|
219 |
+
text_feat = all_gather_grad(text_feat_inp.contiguous())
|
220 |
+
|
221 |
+
image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
|
222 |
+
text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
|
223 |
+
|
224 |
+
temperature = lang_enc.logit_scale
|
225 |
+
logits = torch.matmul(image_feat, text_feat.t())
|
226 |
+
logit_scale = temperature.exp().clamp(max=100)
|
227 |
+
|
228 |
+
gt = torch.arange(logits.shape[0], device=logits.device)
|
229 |
+
loss1 = F.cross_entropy(logit_scale * logits, gt)
|
230 |
+
loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
|
231 |
+
|
232 |
+
return (loss1 + loss2) / 2 # scale it up by the number of GPUs
|
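On a single GPU the all-gather helpers above are no-ops, so vl_contrastive_loss reduces to a symmetric cross-entropy over the image-text similarity matrix with a learnable temperature. A self-contained sketch with random, pre-normalized features; the batch size, feature width, and initial temperature are arbitrary choices for illustration:

import torch
import torch.nn.functional as F

B, C = 4, 512
image_feat = F.normalize(torch.randn(B, C), dim=-1)     # the loss expects pre-normalized inputs
text_feat  = F.normalize(torch.randn(B, C), dim=-1)
temperature = torch.nn.Parameter(torch.tensor(2.659))   # log(1/0.07)-style learnable logit scale

logits = torch.matmul(image_feat, text_feat.t()) * temperature.exp().clamp(max=100)
gt = torch.arange(B)                                     # matching pairs sit on the diagonal
loss = 0.5 * (F.cross_entropy(logits, gt) + F.cross_entropy(logits.t(), gt))
print(loss.item())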
modeling/language/misc.py
ADDED
@@ -0,0 +1,66 @@
import random

import torch
import nltk
import numpy as np

from utilities.constants import IMAGENET_DEFAULT_TEMPLATES

nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)

def get_tag(tokenized, tags):
    if not isinstance(tags, (list, tuple)):
        tags = [tags]
    ret = []
    for (word, pos) in nltk.pos_tag(tokenized):
        for tag in tags:
            if pos == tag:
                ret.append(word)
    return ret

def get_noun_phrase(tokenized):
    # Taken from Su Nam Kim Paper...
    grammar = r"""
        NBAR:
            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns

        NP:
            {<NBAR>}
            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
    """
    chunker = nltk.RegexpParser(grammar)

    chunked = chunker.parse(nltk.pos_tag(tokenized))
    continuous_chunk = []
    current_chunk = []

    for subtree in chunked:
        if isinstance(subtree, nltk.Tree):
            current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
        elif current_chunk:
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)
                current_chunk = []
        else:
            continue

    return continuous_chunk

def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
    tokenized = nltk.word_tokenize(text)

    if random.random() >= phrase_prob:
        nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
    else:
        nouns = get_noun_phrase(tokenized)


    prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]

    if append_text:
        prompt_texts += [text]
        nouns += [text]

    return prompt_texts, nouns
ADDED
@@ -0,0 +1,206 @@
|
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import torch
from torch import nn
from torch.nn import functional as F

from timm.models.layers import trunc_normal_

from .build import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from utilities.prompt_engineering import prompt_engineering, get_prompt_templates

from transformers import AutoTokenizer, AutoModel

class LanguageEncoder(nn.Module):

    @configurable
    def __init__(
        self,
        tokenizer,
        tokenizer_type,
        lang_encoder,
        lang_projection,
        max_token_num,
        queue_operator,
    ):
        super().__init__()
        # seg
        self.tokenizer = tokenizer
        self.tokenizer_type = tokenizer_type
        self.lang_encoder = lang_encoder
        self.lang_proj = lang_projection
        self.max_token_num = max_token_num
        self.logit_scale = nn.Parameter(torch.ones([]))

        # captioning & retrieval
        for key, value in queue_operator.items():
            self.register_buffer(key, value)

        self.biomed_encoder = AutoModel.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")

    @classmethod
    def from_config(cls, cfg):
        # build up text encoder for seg
        tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
        tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
        lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
        max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']

        dim_lang = cfg['MODEL']['TEXT']['WIDTH']
        dim_projection = cfg['MODEL']['DIM_PROJ']
        lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
        trunc_normal_(lang_projection, std=.02)

        # tested not working better
        queue_operator = {}

        return {
            "tokenizer": tokenizer,
            "tokenizer_type": tokenizer_type,
            "lang_encoder": lang_encoder,
            "lang_projection": lang_projection,
            "max_token_num": max_token_num,
            "queue_operator": queue_operator,
        }

    def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True, store_buffer=None):
        if not is_eval:
            if prompt:
                # randomly sample one template
                arbitary_concepts = [
                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
                    for label in range(len(class_names))
                ]
                if add_bgd:
                    arbitary_concepts.append("A background in coco.")
            else:
                arbitary_concepts = class_names

            input_ids = []
            attention_masks = []
            for txt in arbitary_concepts:
                tokens = self.tokenizer(
                    txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                )
                tokens['input_ids'].squeeze_()
                tokens['attention_mask'].squeeze_()

                input_ids.append(tokens['input_ids'])
                attention_masks.append(tokens['attention_mask'])

            arbitary_tokens = torch.stack(input_ids)
            arbitary_attention_masks = torch.stack(attention_masks)

            text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
            setattr(self, '{}_text_embeddings'.format(name), text_emb)
        else:
            with torch.no_grad():
                def extract_mean_emb(txts):
                    tokens = self.tokenizer(
                        txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
                    )
                    clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
                    clss_embedding = clss_embedding.mean(dim=0)
                    clss_embedding /= clss_embedding.norm()
                    return clss_embedding

                templates = get_prompt_templates()
                clss_embeddings = []
                if prompt:
                    for clss in class_names:
                        txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
                        clss_embeddings.append(extract_mean_emb(txts))
                else:
                    for clss in class_names:
                        clss_embeddings.append(extract_mean_emb([clss]))

                if add_bgd:
                    txts = ["A background in coco."]
                    clss_embeddings.append(extract_mean_emb(txts))

                text_emb = torch.stack(clss_embeddings, dim=0)
                setattr(self, '{}_text_embeddings'.format(name), text_emb)

    def reset_text_embeddings(self, name='default'):
        pass

    def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
        if not token:
            tokens = self.tokenizer(
                txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
            )
            tokens = {key: value.cuda() for key, value in tokens.items()}
        else:
            tokens = txts
        token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
        ret = {"tokens": tokens,
               "token_emb": token_emb,
               "class_emb": class_emb,}
        setattr(self, '{}_token_embeddings'.format(name), ret)
        return ret

    def forward_language(self, texts, norm=True):
        if self.tokenizer_type == 'biomed-clip':
            with torch.no_grad():  # Disable gradient calculation
                outputs = self.biomed_encoder(*texts)
                # Extract the last hidden state
                x = outputs['last_hidden_state']
                x = x[:, 0]  # Get the [CLS] token's embeddings for all examples
        else:
            x = self.lang_encoder(*texts)
            x = x['last_hidden_state']

            if self.tokenizer_type == 'clip':
                x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
            else:
                x = x[:, 0]

        x = x @ self.lang_proj
        if norm:
            x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
        return x

    def forward_language_token(self, texts, norm=False):
        if self.tokenizer_type == 'biomed-clip':
            with torch.no_grad():  # Disable gradient calculation
                outputs = self.biomed_encoder(*texts)
                # Extract the last hidden state
                token_x = outputs['last_hidden_state']
                class_x = token_x[:, 0]  # Get the [CLS] token's embeddings for all examples
        else:
            x = self.lang_encoder(*texts)
            token_x = x['last_hidden_state']

            if self.tokenizer_type == 'clip':
                class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
            else:
                class_x = token_x[:, 0]

        class_x = class_x @ self.lang_proj
        token_x = token_x @ self.lang_proj

        if norm:
            class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
            token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)

        return token_x, class_x

    def compute_similarity(self, v_emb, name='default', fake=False):
        if fake:
            return None
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t_emb = getattr(self, '{}_text_embeddings'.format(name))
        output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
        return output


@register_model
def get_language_model(cfg, **kwargs):
    return LanguageEncoder(cfg)
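At inference the typical flow is to cache text embeddings for the target class names once, then score visual query embeddings against them. A hedged sketch of that call sequence follows; it assumes an already constructed LanguageEncoder instance named lang_encoder and CUDA availability (the tokenized batches are moved with .cuda()), and the query count and feature width below are illustrative, not the real config values.

import torch

# lang_encoder: an instance of the LanguageEncoder above, already on GPU.
class_names = ['neoplastic cells', 'inflammatory cells']
lang_encoder.get_text_embeddings(class_names, name='default', is_eval=True)

v_emb = torch.randn(1, 101, 512).cuda()          # [batch, queries, DIM_PROJ] -- illustrative sizes
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
sim = lang_encoder.compute_similarity(v_emb)     # [1, 101, len(class_names)]
print(sim.shape)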
modeling/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .point_features import *
from .position_encoding import *
from .postprocessing import *
from .attention import *
from .criterion import *
from .matcher import *
modeling/modules/attention.py
ADDED
@@ -0,0 +1,487 @@
1 |
+
import warnings
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
8 |
+
from torch.nn.parameter import Parameter
|
9 |
+
from torch.overrides import has_torch_function, handle_torch_function
|
10 |
+
from torch.nn.functional import pad, linear, softmax, dropout
|
11 |
+
|
12 |
+
|
13 |
+
def multi_head_attention_forward(
|
14 |
+
query: Tensor,
|
15 |
+
key: Tensor,
|
16 |
+
value: Tensor,
|
17 |
+
embed_dim_to_check: int,
|
18 |
+
num_heads: int,
|
19 |
+
in_proj_weight: Tensor,
|
20 |
+
in_proj_bias: Tensor,
|
21 |
+
bias_k: Optional[Tensor],
|
22 |
+
bias_v: Optional[Tensor],
|
23 |
+
add_zero_attn: bool,
|
24 |
+
dropout_p: float,
|
25 |
+
out_proj_weight: Tensor,
|
26 |
+
out_proj_bias: Tensor,
|
27 |
+
training: bool = True,
|
28 |
+
key_padding_mask: Optional[Tensor] = None,
|
29 |
+
need_weights: bool = True,
|
30 |
+
attn_mask: Optional[Tensor] = None,
|
31 |
+
use_separate_proj_weight: bool = False,
|
32 |
+
q_proj_weight: Optional[Tensor] = None,
|
33 |
+
k_proj_weight: Optional[Tensor] = None,
|
34 |
+
v_proj_weight: Optional[Tensor] = None,
|
35 |
+
static_k: Optional[Tensor] = None,
|
36 |
+
static_v: Optional[Tensor] = None,
|
37 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
38 |
+
r"""
|
39 |
+
Args:
|
40 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
41 |
+
See "Attention Is All You Need" for more details.
|
42 |
+
embed_dim_to_check: total dimension of the model.
|
43 |
+
num_heads: parallel attention heads.
|
44 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
45 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
46 |
+
add_zero_attn: add a new batch of zeros to the key and
|
47 |
+
value sequences at dim=1.
|
48 |
+
dropout_p: probability of an element to be zeroed.
|
49 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
50 |
+
training: apply dropout if is ``True``.
|
51 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
52 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
53 |
+
the corresponding value on the attention layer will be filled with -inf.
|
54 |
+
need_weights: output attn_output_weights.
|
55 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
56 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
57 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
58 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
59 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
60 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
61 |
+
static_k, static_v: static key and value used for attention operators.
|
62 |
+
|
63 |
+
|
64 |
+
Shape:
|
65 |
+
Inputs:
|
66 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
67 |
+
the embedding dimension.
|
68 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
69 |
+
the embedding dimension.
|
70 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
71 |
+
the embedding dimension.
|
72 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
73 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
74 |
+
will be unchanged. If a BoolTensor is provided, the positions with the
|
75 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
76 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
77 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
78 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
79 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
80 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
81 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
82 |
+
is provided, it will be added to the attention weight.
|
83 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
84 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
85 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
86 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
87 |
+
|
88 |
+
Outputs:
|
89 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
90 |
+
E is the embedding dimension.
|
91 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
92 |
+
L is the target sequence length, S is the source sequence length.
|
93 |
+
"""
|
94 |
+
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
|
95 |
+
if has_torch_function(tens_ops):
|
96 |
+
return handle_torch_function(
|
97 |
+
multi_head_attention_forward,
|
98 |
+
tens_ops,
|
99 |
+
query,
|
100 |
+
key,
|
101 |
+
value,
|
102 |
+
embed_dim_to_check,
|
103 |
+
num_heads,
|
104 |
+
                in_proj_weight,
                in_proj_bias,
                bias_k,
                bias_v,
                add_zero_attn,
                dropout_p,
                out_proj_weight,
                out_proj_bias,
                training=training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight,
                k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight,
                static_k=static_k,
                static_v=static_v,
            )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif key is value or torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        # assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1),
            float("-inf"),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None


# This class exists solely for Transformer; it has an annotation stating
# that bias is never None, which appeases TorchScript
class _LinearWithBias(nn.Linear):
    bias: Tensor  # type: ignore

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)  # type: ignore


class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the position
              with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length.

              If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
              the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
modeling/modules/criterion.py
ADDED
@@ -0,0 +1,874 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
"""
MaskFormer criterion.
"""
import logging

import torch
import torch.nn.functional as F
from torch import nn

from detectron2.utils.comm import get_world_size
from timm.loss import SoftTargetCrossEntropy
from .point_features import (
    get_uncertain_point_coords_with_randomness,
    point_sample,
)

from ..language.loss import ql_multi_contrastive_loss, image_text_contrastive_loss_queue, vl_similarity, all_gather_grad
from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis
from ..utils import box_ops

# from image2html.visualizer import VL


def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
    ):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule


def sigmoid_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
    ):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    return loss.mean(1).sum() / num_masks


sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule


def calculate_uncertainty(logits):
    """
    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
    foreground class in `classes`.
    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images and C is
            the number of foreground classes. The values are logits.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    assert logits.shape[1] == 1
    gt_class_logits = logits.clone()
    return -(torch.abs(gt_class_logits))


class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses,
                 num_points, oversample_ratio, importance_sample_ratio, grounding_weight):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.top_x_layers = top_x_layers
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

        # pointwise mask loss parameters
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio

        # grounding
        self.grounding_weight = grounding_weight

    def loss_labels(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        if layer_id > self.top_x_layers['mask']:
            return {"loss_mask_ce_0": 0}

        if indices is None or len(targets) == 0:
            loss_ce = outputs['pred_logits'].sum() * 0.0
            losses = {"loss_mask_ce_0": loss_ce}
            return losses

        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"].type(self.empty_weight.dtype)

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        if src_logits.shape[2] == self.num_classes+1:
            empty_weight = torch.ones(self.num_classes + 1).to(src_logits.device).type(self.empty_weight.dtype)
            empty_weight[-1] = self.eos_coef
        else:
            empty_weight = torch.ones(self.num_classes + 1000 + 1).to(src_logits.device).type(self.empty_weight.dtype)
            empty_weight[self.num_classes] = self.eos_coef
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes)
        losses = {"loss_mask_ce_0": loss_ce}
        return losses

    def loss_labels_openimage(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        if layer_id > self.top_x_layers['mask']:
            return {"loss_openimage_ce_0": 0}

        assert "pred_captions" in outputs

        if indices is None or len(targets) == 0 or (len(targets) > 0 and len(targets[0]['labels']) == 0):
            loss_ce = outputs['pred_captions'].sum() * 0.0
            losses = {"loss_openimage_ce_0": loss_ce}
            return losses

        # compute i2t loss
        loss_openimage_ce = 0
        losses = {}
        for b in range(len(indices)):
            pred_logit = outputs["pred_logits"][b][indices[b][0]]
            gt_logit = torch.zeros_like(pred_logit)
            select_idx = torch.stack((torch.arange(len(indices[b][1])), indices[b][1])).tolist()
            gt_logit[select_idx] = 1
            loss_openimage_ce += torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
        loss_openimage_ce = loss_openimage_ce / len(indices)
        losses.update({"loss_openimage_ce_0": loss_openimage_ce})
        return losses

    def loss_itc(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['retrieval']:
            return {"loss_retrieval_decoder_0": 0}
        t_emb = torch.cat([x['caption_proj'] for x in targets], dim=0)
        v_emb = outputs['pred_captions'][:,-1]
        loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, extra['lang_encoder'], extra['training'])

        # compute query-token contrastive loss
        ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
        ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
        ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
        vtk_emb = outputs['pred_captions'][:,:-1]
        keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()

        ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)

        # prepare gt
        gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
        gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
        # compute i2t loss
        logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
        loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
        # loss_contrast_fine = loss_contrast_fine_vt # i2t only

        # compute t2i loss
        bs, nq, _ = vtk_emb.shape
        logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
        loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
        # compute loss
        loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)

        losses = {"loss_retrieval_decoder_0": loss_contrast + loss_contrast_fine * 0.5}
        return losses

    def loss_captionings(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['captioning']:
            return {"loss_captioning_0": 0}

        pred_captions_gen = outputs['pred_captionings'][:, :-1]
        token_embs = extra['token_embedding'].weight
        # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
        # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
        pred_captions_gen = pred_captions_gen @ token_embs.t()

        # temperature = extra['lang_encoder'].logit_scale
        # logit_scale = temperature.exp().clamp(max=100)

        target_captions_gen = torch.cat([target['caption_tokenids'] for target in targets], 0)[:, 1:]
        target_captions_gen_mask = torch.cat([target['caption_mask'] for target in targets], 0)[:, 1:]

        # loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2) * logit_scale, target_captions_gen, reduction='none')
        loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2), target_captions_gen, reduction='none')
        loss_caption = (loss_caption * target_captions_gen_mask).sum() / (target_captions_gen_mask.sum() + 1)
        losses = {"loss_captioning_0": loss_caption}
        return losses

    def loss_captions(self, outputs, targets, indices, num_masks, layer_id, extra):
        if layer_id >= self.top_x_layers['caption']:
            return {"loss_caption_0": 0}
        matched_tokens = [m[0] for m in indices]
        t_emb_class = torch.cat([extra['class_embeddings'][targets[bs]['labels'][m[1]]] for bs, m in enumerate(indices)])
        t_hash_class = torch.cat([torch.tensor(targets[bs]['labels_hash'])[m[1]] for bs, m in enumerate(indices)])

        # pred_captions denotes all unmatched object queries.
        unmatched_pred_captions = []
        matched_pred_captions = []
        for idx, m in enumerate(matched_tokens):
            unmatched_masks = torch.ones(outputs['pred_captions'].shape[1:-1]).bool()
            matched_masks = torch.zeros(outputs['pred_captions'].shape[1:-1]).bool()

            unmatched_masks[m] = False
            matched_masks[m] = True

            unmatched_pred_captions.append(outputs['pred_captions'][idx][unmatched_masks])
            matched_pred_captions.append(outputs['pred_captions'][idx][matched_masks])

        outputs['unmatched_pred_captions'] = unmatched_pred_captions
        v_emb_class = torch.cat(matched_pred_captions)
        v_emb_class = v_emb_class / (v_emb_class.norm(dim=-1, keepdim=True) + 1e-7)

        indices = self.matcher(outputs, targets, mode="caption_womask", extra={'temperature':extra['lang_logit']})
        src_idx = self._get_src_permutation_idx(indices)

        t_emb = torch.cat([t['captions'][indices[bs][1]] for bs,t in enumerate(targets)])
        t_hash = torch.cat([torch.tensor(t['captions_hash'])[indices[bs][1]] for bs,t in enumerate(targets)])

        unmatched_pred_captions, _ = nested_tensor_from_tensor_list(unmatched_pred_captions).decompose()
        v_emb = unmatched_pred_captions[src_idx]
        v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

        loss_contrast = ql_multi_contrastive_loss(torch.cat((v_emb, v_emb_class)), torch.cat((t_emb, t_emb_class)), torch.cat((t_hash, t_hash_class)), temperature=extra['lang_logit'])
        losses = {"loss_caption_0": loss_contrast}

        return losses

    def loss_masks(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        if layer_id >= self.top_x_layers['mask']:
            return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0}

        assert "pred_masks" in outputs
        if indices is None or len(targets) == 0:
            loss = outputs['pred_masks'].sum() * 0.0
            losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss}
            return losses

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        losses = {
            "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
            "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
        }

        del src_masks
        del target_masks
        return losses

    def loss_groundings(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_gmasks" in outputs
        assert "pred_gtexts" in outputs

        if layer_id >= self.top_x_layers['grounding']:
            return {"loss_grounding_bce_0": 0, "loss_grounding_dice_0": 0, "loss_grounding_ce_0": 0}

        masks = [t["grounding_masks"] for t in targets]
        if indices is None or None in masks:
            loss = outputs['pred_gmasks'].sum() * 0.0
            return {"loss_grounding_bce_0": loss, "loss_grounding_dice_0": loss, "loss_grounding_ce_0": loss}

        pred_logits = []
        for b in range(len(indices)):
            t_emb = targets[b]['grounding_class_embs']
            v_emb = outputs["pred_gtexts"][b]

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
            pred_logits += [out_prob]
        outputs['pred_logits'] = pred_logits

        indices = self.matcher(outputs, targets, mode='grounding', extra={'temperature':extra['lang_logit']})
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_gmasks"]
        src_masks = src_masks[src_idx]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        losses = {
            "loss_grounding_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, len(src_masks)),
            "loss_grounding_dice_0": dice_loss_jit(point_logits, point_labels, len(src_masks)),
        }

        # compute query-token contrastive loss
        # ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
        # ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
        # ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
        # vtk_emb = outputs['pred_captions'][:,:-1]
        # keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()

        # ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        # vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
        # logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)

        # # prepare gt
        # gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
        # gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
        # # compute i2t loss
        # logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
        # loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
        # # loss_contrast_fine = loss_contrast_fine_vt # i2t only

        # # compute t2i loss
        # bs, nq, _ = vtk_emb.shape
        # logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
        # loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
        # # compute loss
        # loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)

        # compute t2i loss
        loss_grd_ce = 0
        for b in range(len(indices)):
            task = targets[b]['grounding_task']
            pred_logit = outputs["pred_logits"][b]
            gt_logit = torch.zeros_like(pred_logit)
            select_idx = torch.stack((indices[b][0], indices[b][1])).tolist()
            gt_logit[select_idx] = 1
            t_hash = torch.tensor(targets[b]['grounding_hash'], device=gt_logit.device)
            hash_table = torch.zeros((len(t_hash), len(t_hash)), device=gt_logit.device)
            for idx in range(0, len(hash_table)):
                hash_table[idx][t_hash==t_hash[idx]] = 1
            hash_table = hash_table / hash_table.sum(-1, keepdim=True)
            gt_logit = gt_logit @ hash_table
            loss_grd_ce += self.grounding_weight[task]*torch.sum(-gt_logit.t() * F.log_softmax(pred_logit.t(), dim=-1), dim=-1).mean()
        loss_grd_ce = loss_grd_ce / len(indices)
        losses.update({"loss_grounding_ce_0": loss_grd_ce})
        del src_masks
        del target_masks
        return losses

    def loss_spatials(self, outputs, targets, indices, num_masks, layer_id, extra):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_smasks" in outputs
        assert "pred_smaskembs" in outputs

        if layer_id >= self.top_x_layers['spatial']:
            loss = outputs['pred_smasks'].sum() * 0.0
            loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
            return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}

        gt_masks = [x['gt_spatial_masks'] for x in targets]
        # compute a keep index with batch size to avoid empty gt_masks
        stack_gt_mask = torch.cat(gt_masks)
        bs,_,_ = stack_gt_mask.shape
        stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
        keep = stack_gt_mask > 0  # only keep sample contain positive mask

        if keep.sum() == 0:
            loss = outputs['pred_smasks'].sum() * 0.0
            loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
            return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}

        # mask embedding logits
        v_emb = outputs["pred_smaskembs"]  # [bs, nq, 512]

        # pos mask
        s_emb = outputs["pred_pspatials"]  # [bs, ns, 512]
        pred_logits = v_emb @ s_emb.transpose(1,2)
        outputs['pred_pos_logits'] = pred_logits  # [bs, nq, 1]
        indices = self.matcher(outputs, targets, mode='spatial', extra={})
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        # pos class loss
        pred_logit = torch.cat([o[:len(t['gt_spatial_masks'])] for o,t in zip(outputs["pred_pos_logits"].transpose(1,2), targets)])
        gt_logit = torch.zeros_like(pred_logit)
        gt_logit = gt_logit[keep]
        _src_idx = [torch.arange(keep.sum(), device=src_idx[0].device), src_idx[1][keep.cpu()]]
        gt_logit[_src_idx] = 1
        pred_logit = pred_logit[keep]
        loss_spa_ce_pos = torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()

        # neg mask
        # s_emb = outputs["pred_nspatials"] # [bs, ns, 512]
        # neg_mask = (s_emb.sum(dim=list(range(1, len(s_emb.shape)))) != 0).float()[keep]
        # pred_logits = v_emb @ s_emb.transpose(1,2)
        # outputs['pred_neg_logits'] = pred_logits # [bs, nq, 1]
        # indices = self.matcher(outputs, targets, mode='spatial_pn', extra=extra)
        # src_idx = self._get_src_permutation_idx(indices)
        # tgt_idx = self._get_tgt_permutation_idx(indices)
        # src_masks_neg = outputs["pred_smasks"][src_idx][keep]
        # src_masks_neg = src_masks_neg*(neg_mask[:,None,None])
        # src_masks_neg = src_masks_neg.clip(0) * (-1)

        # neg class loss
        # pred_logit = outputs["pred_neg_logits"]
        # gt_logit = torch.zeros_like(pred_logit)
        # gt_logit[src_idx] = 1
        # bs,_,ns = pred_logit[keep].shape
        # pred_logit = pred_logit[keep].transpose(1,2).view(bs*ns,-1)
        # gt_logit = gt_logit[keep].transpose(1,2).view(bs*ns,-1)
        # loss_spa_ce_neg = (torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1)*neg_mask).sum() / (neg_mask.sum()+1e-6)

        # recompute a keep index with matched tgt
        stack_gt_mask = nn.utils.rnn.pad_sequence(gt_masks, padding_value=-1).transpose(0,1)[tgt_idx]
        bs,_,_ = stack_gt_mask.shape
        target_masks = stack_gt_mask
        stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
        keep = stack_gt_mask > 0  # only keep sample contain positive mask
        src_masks_pos = outputs["pred_smasks"][src_idx][keep]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks = target_masks.to(src_masks_pos)
        target_masks = target_masks[keep]

        # mul = extra['spatial_query_mode'][keep]
        # src_masks_cur = src_masks_cur.clip(0) * mul[:,None,None]
        # src_masks_cur = src_masks_cur

        # if neg_mask[0] == 1:
        #     import cv2
        #     print(src_masks_pos.shape)
        #     print(src_masks_neg.shape)
        #     print(target_masks.shape)
        #     # import pdb; pdb.set_trace()
        #     v_pos_mask = (src_masks_pos[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_neg_mask = (_src_masks_neg[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_sum = ((src_masks_pos[0]-_src_masks_neg[0].clip(0)).sigmoid() > 0.5).float().cpu().detach().numpy() * 255
        #     v_gt = target_masks[0].float().cpu().detach().numpy() * 255

        #     cv2.imwrite('v_pos_mask.png', v_pos_mask)
        #     cv2.imwrite('v_neg_mask.png', v_neg_mask)
        #     cv2.imwrite('v_sum.png', v_sum)
        #     cv2.imwrite('v_gt.png', v_gt)
        #     import pdb; pdb.set_trace()

        # src_masks = (src_masks_pos + src_masks_neg)[:, None]
        src_masks = src_masks_pos[:, None]
        target_masks = target_masks[:, None]

        # debug visualization
        # with torch.no_grad():
        #     import cv2
        #     import numpy as np

        #     v_src_masks = (F.interpolate(src_masks, size=target_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5).float().cpu().numpy()[:,0] * 255
        #     v_target_masks = target_masks.float().cpu().numpy()[:,0] * 255
        #     v_masks = np.concatenate([v_src_masks, v_target_masks], axis=2)

        #     for i in range(len(src_masks)):
        #         v1 = v_src_masks[i]
        #         v2 = v_target_masks[i]
        #         v = np.concatenate([v1,v2], axis=1)
        #         cv2.imwrite('v{}.png'.format(i), v)
        #     import pdb; pdb.set_trace()

        # visualization
        # VL.step()
        # v_img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
        # VL.add_image(v_img[:,:,::-1])
        # candidate_masks = batched_inputs[0]['spatial_query']['rand_shape'].float().cpu().numpy()
        # gt_masks = batched_inputs[0]['spatial_query']['gt_masks'].float().cpu().numpy()
        # texts = ['cmask' for i in range(len(candidate_masks))]
        # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], candidate_masks, texts)
        # texts = ['gmask' for i in range(len(candidate_masks))]
        # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], gt_masks, texts)

        # import cv2
        # for i in range(len(src_masks)):
        #     visual_src_mask_cur = (src_masks_cur[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_src_mask_mem = (src_masks_mem[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_src_mask = (src_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
        #     visual_target_mask = (target_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255

        #     cv2.imwrite('visual_src_mask_cur_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_cur)
        #     cv2.imwrite('visual_src_mask_mem_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_mem)
        #     cv2.imwrite('visual_src_mask_{}_{}.png'.format(i, mul[i].item()), visual_src_mask)
        #     cv2.imwrite('visual_target_mask_{}_{}.png'.format(i, mul[i].item()), visual_target_mask)
        #     import pdb; pdb.set_trace()

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            ).type(src_masks.dtype)
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        num_masks = len(src_masks)
        losses = {
            "loss_spatial_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
            "loss_spatial_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
        }

        # losses.update({"loss_spatial_ce_0": loss_spa_ce_pos + loss_spa_ce_neg})
        losses.update({"loss_spatial_ce_0": loss_spa_ce_pos})

        del src_masks
        del target_masks
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id, extra):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if layer_id >= self.top_x_layers['box']:
            return {"loss_bbox_0": 0, "loss_giou_0": 0}

        assert 'pred_boxes' in outputs

        if indices is None or len(targets) == 0:
            loss = outputs['pred_boxes'].sum() * 0.0
            losses = {"loss_bbox_0": loss, "loss_giou_0": loss}
            return losses

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"]
        src_boxes = src_boxes[src_idx].sigmoid()

        target_boxes = [t['boxes'] for t in targets]
        max_size = _max_by_axis([list(box.shape) for box in target_boxes])
        max_size = [len(target_boxes)] + max_size
        empty_boxes = torch.zeros(max_size).to(src_boxes.device)
        for idx, tar_box in enumerate(target_boxes):
            empty_boxes[idx,:tar_box.shape[0],:] = tar_box
        target_boxes = empty_boxes[tgt_idx]

        # target_isthings = [t['is_things'] for t in targets]
        # max_size = _max_by_axis([list(lab.shape) for lab in target_isthings])
        # max_size = [len(target_isthings)] + max_size
        # empty_lab = torch.zeros(max_size).to(src_boxes.device)

        # for idx, tar_thing in enumerate(target_isthings):
        #     empty_lab[idx,:tar_thing.shape[0]] = tar_thing
        # target_isthings = empty_lab[tgt_idx]

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou_0'] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks, layer_id, extra):
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
            'boxes': self.loss_boxes,
            'captions': self.loss_captions,
            'retrievals': self.loss_itc,
            'captionings': self.loss_captionings,
            'groundings': self.loss_groundings,
            'labels_openimage': self.loss_labels_openimage,
            'spatials': self.loss_spatials,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks, layer_id, extra)

    def forward(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs_without_aux.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def forward_vlp(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        # Compute all the requested losses
        losses = {}
        num_masks = indices = None
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def forward_grounding(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        # Compute all the requested losses
        losses = {}
        indices = [[] for i in range(len(targets))]

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_masks = sum(len(t["grounding_masks"]) for t in targets) + 1e-7
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def forward_openimage(self, outputs, targets, extra=None):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        neg_class_emb = all_gather_grad(torch.cat([x['neg_class_emb'] for x in targets]))
        neg_hash = all_gather_grad(torch.cat([x['neg_hash'] for x in targets]))

        extra['neg_class_emb'] = neg_class_emb
        extra['neg_hash'] = neg_hash
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices, pred_logits = self.matcher.openimage_forward(outputs_without_aux, targets, extra=extra)
        outputs['pred_logits'] = pred_logits

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=neg_class_emb.device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            # NOTE: we reverse the aux_outputs so that the first is the second last layer
            for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
                indices, pred_logits = self.matcher.openimage_forward(aux_outputs, targets, extra=extra)
                aux_outputs['pred_logits'] = pred_logits
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
                    l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def __repr__(self):
        head = "Criterion " + self.__class__.__name__
        body = [
            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
            "losses: {}".format(self.losses),
            "weight_dict: {}".format(self.weight_dict),
|
866 |
+
"num_classes: {}".format(self.num_classes),
|
867 |
+
"eos_coef: {}".format(self.eos_coef),
|
868 |
+
"num_points: {}".format(self.num_points),
|
869 |
+
"oversample_ratio: {}".format(self.oversample_ratio),
|
870 |
+
"importance_sample_ratio: {}".format(self.importance_sample_ratio),
|
871 |
+
]
|
872 |
+
_repr_indent = 4
|
873 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
874 |
+
return "\n".join(lines)
|
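Editorial note: the forward variants above all share one bookkeeping pattern — losses from the last decoder layer keep the `_0` suffix, and each auxiliary (intermediate) layer gets the same loss keys re-suffixed with its own index so they can be weighted and logged separately. A minimal, self-contained sketch of that renaming step (illustration only, not part of this diff; the loss names and values are made up):

# Editorial sketch: how per-layer loss keys are derived from the layer-0 keys.
layer0_losses = {"loss_mask_ce_0": 0.7, "loss_mask_dice_0": 0.4}  # toy values

all_losses = dict(layer0_losses)
num_aux_layers = 3  # hypothetical number of intermediate decoder layers
for i in range(num_aux_layers):
    # in the real code each layer recomputes its own losses via self.get_loss(...)
    l_dict = {k.replace("_0", f"_{i+1}"): v for k, v in layer0_losses.items()}
    all_losses.update(l_dict)

print(sorted(all_losses))
# ['loss_mask_ce_0', 'loss_mask_ce_1', 'loss_mask_ce_2', 'loss_mask_ce_3',
#  'loss_mask_dice_0', 'loss_mask_dice_1', 'loss_mask_dice_2', 'loss_mask_dice_3']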
modeling/modules/matcher.py
ADDED
@@ -0,0 +1,632 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import warnings
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import linear_sum_assignment
from torch import nn
from torch.cuda.amp import autocast

from .point_features import point_sample
from ..language.loss import vl_similarity

def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


batch_dice_loss_jit = torch.jit.script(
    batch_dice_loss
)  # type: torch.jit.ScriptModule

def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    hw = inputs.shape[1]

    pos = F.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    neg = F.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )

    loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
        "nc,mc->nm", neg, (1 - targets)
    )

    return loss / hw


batch_sigmoid_ce_loss_jit = torch.jit.script(
    batch_sigmoid_ce_loss
)  # type: torch.jit.ScriptModule

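Editorial note: unlike the per-pair training losses, both helpers above return a pairwise [num_predictions, num_targets] cost matrix via the "nc,mc->nm" einsum, which is exactly the shape the matcher below consumes. A small sanity-check sketch (not part of this diff), assuming 3 predicted masks and 2 ground-truth masks sampled at the same 4 points:

# Editorial sketch: shapes of the pairwise matching costs.
import torch

pred_logits = torch.randn(3, 4)              # 3 predicted masks, logits at 4 sampled points
gt_masks = torch.tensor([[1., 0., 1., 1.],
                         [0., 0., 1., 0.]])  # 2 ground-truth masks at the same points

dice_cost = batch_dice_loss(pred_logits, gt_masks)      # pairwise cost, shape [3, 2]
ce_cost = batch_sigmoid_ce_loss(pred_logits, gt_masks)  # pairwise cost, shape [3, 2]
assert dice_cost.shape == ce_cost.shape == (3, 2)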
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, spatial_cost=None):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

        self.num_points = num_points
        self.spatial_cost_class = cost_class
        self.spatial_cost_mask = cost_mask
        self.spatial_cost_dice = cost_dice
        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"

    @torch.no_grad()
    def memory_efficient_forward(self, outputs, targets):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []

        # Iterate through batch size
        for b in range(bs):
            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
            tgt_ids = targets[b]["labels"]

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]
            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)

                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

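Editorial note: the weighted sum C is just a [num_queries, num_targets] cost matrix per image, and scipy's linear_sum_assignment solves the resulting assignment problem so that each matched query is paired with exactly one target at minimum total cost. A tiny standalone example (not part of this diff) showing how the returned index pairs come out:

# Editorial sketch: 4 queries, 2 targets; lower cost = better match.
import numpy as np
from scipy.optimize import linear_sum_assignment

C = np.array([[0.9, 0.1],
              [0.2, 0.8],
              [0.5, 0.5],
              [0.7, 0.6]])
row_ind, col_ind = linear_sum_assignment(C)
print(row_ind, col_ind)  # [0 1] [1 0]: query 0 -> target 1, query 1 -> target 0

The unmatched queries (2 and 3 here) are the ones the criterion later treats as "no object".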
    @torch.no_grad()
    def openimage_forward(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_captions"].shape[:2]
        if bs == 0 or len(targets) == 0:
            return None

        neg_class_emb = extra['neg_class_emb']
        neg_hash = extra['neg_hash']
        _, unique_indices = np.unique(neg_hash.cpu().numpy(), return_index=True)
        neg_class_emb = neg_class_emb[unique_indices]
        neg_hash = neg_hash[unique_indices]

        indices = []
        pred_logits = []
        # Iterate through batch size
        for b in range(bs):
            _pos_class_emb = targets[b]['pos_class_emb']
            _pos_hash = targets[b]['pos_hash']
            _neg_overlap_pos = ~(neg_hash[..., None] == _pos_hash).any(-1)
            _neg_class_emb = neg_class_emb[_neg_overlap_pos]
            t_emb = torch.cat((_pos_class_emb, _neg_class_emb))
            v_emb = outputs["pred_captions"][b]
            del _pos_class_emb
            del _neg_class_emb

            t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
            v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

            out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
            pred_logits += [out_prob]
            out_prob = out_prob.softmax(-1)
            tgt_ids = targets[b]["labels"]
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]
            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)

                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ], pred_logits

    @torch.no_grad()
    def grounding_forward(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_gmasks"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        # Iterate through batch size
        for b in range(bs):
            out_prob = outputs["pred_logits"][b]
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob.softmax(dim=0)

            out_mask = outputs["pred_gmasks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["grounding_masks"].to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]

            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)

                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def spatial_forward(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_smasks"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        # Iterate through batch size
        for b in range(bs):
            out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["gt_spatial_masks"].to(out_mask)
            nd, ns = outputs["pred_pos_logits"][b].shape
            index_masking = 1 - torch.eye(ns, device=out_mask.device, dtype=tgt_mask.dtype).repeat_interleave(nd//ns, dim=0)
            neg_masking = torch.zeros((nd, ns), device=out_mask.device, dtype=tgt_mask.dtype)
            neg_masking.masked_fill_(index_masking.bool(), -float('inf'))
            pos_masking = torch.zeros((nd, ns), device=out_mask.device, dtype=tgt_mask.dtype)
            pos_masking.masked_fill_(index_masking.bool(), float('inf'))
            out_prob = (outputs["pred_pos_logits"][b] + neg_masking)[:, :len(tgt_mask)]  # remove redundant predictions for padding
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob.softmax(dim=0)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]

            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + pos_masking[:, :len(tgt_mask)]
                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + pos_masking[:, :len(tgt_mask)]

            # Final cost matrix
            C = (
                self.spatial_cost_mask * cost_mask
                + self.spatial_cost_class * cost_class
                + self.spatial_cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def spatial_forward_pn(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_smasks"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        fp_mask = extra['false_positive_mask']
        gt_mask = torch.stack([targets[b]["gt_spatial_masks"] for b in range(bs)])

        indices = []
        # Iterate through batch size
        for b in range(bs):
            out_prob = outputs["pred_neg_logits"][b]
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob.softmax(dim=0)

            out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
            tgt_mask = fp_mask[b].to(out_mask)
            ign_mask = (gt_mask[b] | fp_mask[b]).to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]
            ign_mask = ign_mask[:, None]

            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)

            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            ign_mask = point_sample(
                ign_mask,
                point_coords.repeat(ign_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                ign_mask = ign_mask.float()

                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)

                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask*ign_mask, tgt_mask*ign_mask)

            # Final cost matrix
            C = (
                self.spatial_cost_mask * cost_mask
                + self.spatial_cost_class * cost_class
                + self.spatial_cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def caption_forward_womask(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, _ = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        t_emb = torch.cat([t['captions'] for t in targets])
        v_emb = outputs['unmatched_pred_captions']
        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])

        # Iterate through batch size
        for b in range(bs):
            v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
            num_queries = len(v_emb[b])
            out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
            tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            # Final cost matrix
            C = (self.cost_class * cost_class)
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def caption_forward_wmask(self, outputs, targets, extra):
        """More memory-friendly matching"""
        bs, _ = outputs["pred_logits"].shape[:2]

        if bs == 0 or len(targets) == 0:
            return None

        indices = []
        t_emb = torch.cat([t['captions'] for t in targets])
        v_emb = outputs['unmatched_pred_captions']
        caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])

        # Iterate through batch size
        for b in range(bs):
            v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
            num_queries = len(v_emb[b])

            out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
            tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            out_mask = out_mask[:, None]
            tgt_mask = tgt_mask[:, None]
            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
            # get gt labels
            tgt_mask = point_sample(
                tgt_mask,
                point_coords.repeat(tgt_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            out_mask = point_sample(
                out_mask,
                point_coords.repeat(out_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)

                # Compute the dice loss between masks
                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()
            if C.isnan().any():
                C[C.isnan()] = 1e6  ### temporary fix
                warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                raise
            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def forward(self, outputs, targets, mode='default', extra={}):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        if mode == 'default':
            return self.memory_efficient_forward(outputs, targets)
        elif mode == 'grounding':
            return self.grounding_forward(outputs, targets, extra)
        elif mode == 'spatial':
            return self.spatial_forward(outputs, targets, extra)
        elif mode == 'spatial_pn':
            return self.spatial_forward_pn(outputs, targets, extra)
        elif mode == 'caption_womask':
            return self.caption_forward_womask(outputs, targets, extra)
        elif mode == 'caption_wmask':
            return self.caption_forward_wmask(outputs, targets, extra)
        else:
            assert False, "Mode {} is not supported.".format(mode)

    def __repr__(self, _repr_indent=4):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
        ]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
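Editorial note: a short usage sketch tying the pieces together (not part of this diff; the tensor sizes and cost weights are made-up toy values, and it assumes the module above is importable so that HungarianMatcher is in scope). mode='default' routes to memory_efficient_forward, and each tuple in the result pairs query indices with target indices for one image:

# Editorial sketch: matching 10 queries against 3 ground-truth masks for one image.
import torch

matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0, num_points=64)

outputs = {
    "pred_logits": torch.randn(1, 10, 8),      # batch 1, 10 queries, 8 classes (incl. no-object)
    "pred_masks": torch.randn(1, 10, 32, 32),  # predicted mask logits
}
targets = [{
    "labels": torch.tensor([1, 4, 6]),
    "masks": (torch.randn(3, 32, 32) > 0),     # 3 binary ground-truth masks
}]

indices = matcher(outputs, targets, mode="default")
pred_idx, tgt_idx = indices[0]
print(pred_idx.tolist(), tgt_idx.tolist())     # e.g. [2, 5, 9] [0, 1, 2]

The criterion in modeling/modules/criterion.py consumes exactly this indices structure when it calls self.matcher(...) before computing the per-mask losses.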