kernel-luso-comfort committed
Commit 6ba63c9 · Parent: cbd253a

Add initial module structure and entry points for modeling and utilities

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. Dockerfile +77 -0
  3. README.md +5 -5
  4. colabs/ENVIRONMENT.md +6 -0
  5. colabs/biomedparse_inference_demo.py +156 -0
  6. colabs/environment.yml +149 -0
  7. colabs/requirements-colab-pip-freeze.txt +567 -0
  8. colabs/requirements-colab.txt +39 -0
  9. configs/biomedparse_inference.yaml +204 -0
  10. entrypoint.sh +5 -0
  11. examples/Part_1_516_pathology_breast.png +3 -0
  12. inference_utils/inference.py +149 -0
  13. inference_utils/output_processing.py +91 -0
  14. inference_utils/processing_utils.py +182 -0
  15. inference_utils/target_dist.json +1 -0
  16. main.py +106 -0
  17. modeling/BaseModel.py +45 -0
  18. modeling/__init__.py +1 -0
  19. modeling/architectures/__init__.py +5 -0
  20. modeling/architectures/build.py +22 -0
  21. modeling/architectures/seem_model_demo.py +923 -0
  22. modeling/architectures/seem_model_v0.py +1160 -0
  23. modeling/architectures/seem_model_v1.py +1179 -0
  24. modeling/architectures/xdecoder_model.py +937 -0
  25. modeling/body/__init__.py +10 -0
  26. modeling/body/build.py +13 -0
  27. modeling/body/xdecoder_head.py +126 -0
  28. modeling/interface/__init__.py +13 -0
  29. modeling/interface/build.py +14 -0
  30. modeling/interface/modules.py +200 -0
  31. modeling/interface/prototype/__init__.py +0 -0
  32. modeling/interface/prototype/attention_data_struct_seemdemo.py +265 -0
  33. modeling/interface/prototype/attention_data_struct_seemv0.py +264 -0
  34. modeling/interface/prototype/attention_data_struct_seemv1.py +302 -0
  35. modeling/interface/seem_demo.py +397 -0
  36. modeling/interface/seem_v0.py +392 -0
  37. modeling/interface/seem_v1.py +389 -0
  38. modeling/interface/xdecoder.py +497 -0
  39. modeling/language/LangEncoder/__init__.py +35 -0
  40. modeling/language/LangEncoder/build.py +16 -0
  41. modeling/language/LangEncoder/transformer.py +222 -0
  42. modeling/language/__init__.py +10 -0
  43. modeling/language/build.py +14 -0
  44. modeling/language/loss.py +232 -0
  45. modeling/language/misc.py +66 -0
  46. modeling/language/vlpencoder.py +206 -0
  47. modeling/modules/__init__.py +6 -0
  48. modeling/modules/attention.py +487 -0
  49. modeling/modules/criterion.py +874 -0
  50. modeling/modules/matcher.py +632 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,77 @@
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM continuumio/miniconda3:latest
+
+ # Add build argument to force rebuild
+ ARG CACHEBUST=1
+
+ # Avoid tzdata interactive configuration
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV TZ=UTC
+
+ # Install system dependencies
+ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+     git \
+     build-essential \
+     python3-dev \
+     wget \
+     openmpi-bin \
+     libopenmpi-dev \
+     libopenmpi3 \
+     libhwloc15 \
+     libevent-dev \
+     libpmix2 \
+     libgl1 \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set up OpenMPI environment
+ ENV OMPI_MCA_btl_vader_single_copy_mechanism=none \
+     OMPI_ALLOW_RUN_AS_ROOT=1 \
+     OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \
+     PATH=/usr/lib/x86_64-linux-gnu/openmpi/bin:$PATH \
+     LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/openmpi/lib:/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
+
+ # Copy environment file
+ COPY colabs/environment.yml /tmp/environment.yml
+
+ # Create conda environment
+ RUN conda env create -f /tmp/environment.yml && \
+     conda run -n biomedparse pip install gradio==3.50.2
+
+ # Initialize conda in bash
+ RUN conda init bash
+
+ # Make RUN commands use the new environment
+ SHELL ["conda", "run", "-n", "biomedparse", "/bin/bash", "-c"]
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set up HF token for the user
+ RUN --mount=type=secret,id=HF_TOKEN,mode=0444,required=true \
+     echo "export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)" >> $HOME/.bashrc
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy all files to the app directory
+ COPY --chown=user . $HOME/app
+
+ # Set permissions for entrypoint script
+ RUN chmod 755 $HOME/app/entrypoint.sh
+
+ # Add conda environment to user's path
+ RUN echo "conda activate biomedparse" >> $HOME/.bashrc
+
+ # Use entrypoint script to set up environment and run application
+ ENTRYPOINT ["/bin/bash", "-c"]
+ CMD ["exec /home/user/app/entrypoint.sh"]
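Note on the HF_TOKEN secret: the Dockerfile only exposes the token (via the build secret and .bashrc), and how main.py consumes it is not part of this diff. A minimal sketch of the assumed runtime usage, hypothetical rather than the actual main.py, would read the variable and pass it to huggingface_hub when pulling the gated checkpoint:

# Hypothetical sketch: consume the HF_TOKEN exposed by the Dockerfile/entrypoint
# to download the gated BiomedParse checkpoint at container start-up.
import os
from huggingface_hub import hf_hub_download

token = os.environ.get("HF_TOKEN")  # injected via the Space secret
model_file = hf_hub_download(
    repo_id="microsoft/BiomedParse",
    filename="biomedparse_v1.pt",
    local_dir="pretrained",
    token=token,  # required because the checkpoint is gated
)
print(f"Checkpoint available at {model_file}")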
README.md CHANGED
@@ -1,11 +1,11 @@
  ---
- title: BiomedParse
- emoji: 📊
- colorFrom: gray
- colorTo: purple
+ title: Biomedparse Docker
+ emoji: 📉
+ colorFrom: yellow
+ colorTo: blue
  sdk: docker
  pinned: false
- short_description: BiomedParse
+ license: cc-by-nc-sa-4.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
colabs/ENVIRONMENT.md ADDED
@@ -0,0 +1,6 @@
+ # Description of Google Colab Environment
+
+ - Hardware: Python 3 Google Compute Engine Backend on T4 GPU
+ - CUDA version: 12.2
+ - Driver version: 535.104.05
+ - Python version: 3.10.12
colabs/biomedparse_inference_demo.py ADDED
@@ -0,0 +1,156 @@
+ # -*- coding: utf-8 -*-
+ """biomedparse_inference_demo.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1jL4wvdtBWz6G_yBkFn8tyDD0hV1RtKVZ
+
+ # BiomedParse Inference Demo Notebook
+
+ Welcome to the demo notebook for BiomedParse, a comprehensive tool for biomedical image analysis. BiomedParse is designed to simultaneously handle segmentation, detection, and recognition tasks across major biomedical image modalities, providing a unified solution for complex image analysis in biomedical research.
+
+ [[`Paper`](https://aka.ms/biomedparse-paper)] [[`Demo`](https://microsoft.github.io/BiomedParse/)] [[`Model`](https://huggingface.co/microsoft/BiomedParse)] [[`Data`](https://huggingface.co/datasets/microsoft/BiomedParseData)]
+
+ ## Model Checkpoint Access
+
+ The BiomedParse model checkpoint is hosted on [HuggingFace](https://huggingface.co/microsoft/BiomedParse). To access the model:
+
+ 1. Visit the [model page](https://huggingface.co/microsoft/BiomedParse).
+ 2. Make sure to review and accept the terms of use to gain access to the checkpoint.
+ 3. Retrieve your HuggingFace access token from your user profile.
+
+ ## Setting Up Access
+
+ To use the model, set your Hugging Face access token in the HF_TOKEN environment variable or as a Colab secret. This step ensures secure and authorized access to the model resources.
+ """
+
+ # Set your Hugging Face access token in your environment
+ # import os
+ # os.environ['HF_TOKEN'] = 'your_huggingface_access_token_here'
+
+ # Or, if you are using Google Colab, set HF_TOKEN in Colab secrets.
+
+ from google.colab import userdata
+ import huggingface_hub
+
+ huggingface_hub.login(userdata.get('HF_TOKEN'))
+
+ from huggingface_hub import hf_hub_download
+
+ model_file = hf_hub_download(repo_id="microsoft/BiomedParse", filename="biomedparse_v1.pt", local_dir="pretrained")
+
+ print(f"Downloaded model file to: {model_file}")
+
+ """## Environment Setup"""
+
+ !git clone https://github.com/microsoft/BiomedParse
+
+ !pip install -r BiomedParse/assets/requirements/requirements.txt
+
+ """# Restart Colab Runtime"""
+
+ # Make sure to restart the Colab runtime after installing dependencies
+ import os
+ try:
+     import google.colab
+     os._exit(0)
+ except ImportError:
+     pass
+
+ import os
+ os.chdir('/content/BiomedParse')
+ print(os.getcwd())
+
+ """## Load the model weights"""
+
+ from PIL import Image
+ import torch
+ import argparse
+ import numpy as np
+ from modeling.BaseModel import BaseModel
+ from modeling import build_model
+ from utilities.distributed import init_distributed  # changed from utils
+ from utilities.arguments import load_opt_from_config_files
+ from utilities.constants import BIOMED_CLASSES
+ from inference_utils.inference import interactive_infer_image
+
+ conf_files = "configs/biomedparse_inference.yaml"
+ opt = load_opt_from_config_files([conf_files])
+ opt = init_distributed(opt)
+
+ model_file = "../pretrained/biomedparse_v1.pt"
+
+ model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
+ with torch.no_grad():
+     model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)
+
+ """# Run Inference"""
+
+ # RGB image input of shape (H, W, 3). Currently only batch size 1 is supported.
+ image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png'])
+ image = image.convert('RGB')
+
+ # Text prompts querying objects in the image. Multiple prompts can be provided.
+ prompts = ['neoplastic cells', 'inflammatory cells']
+
+ pred_mask = interactive_infer_image(model, image, prompts)
+ pred_mask.shape
+
+ # Load the ground truth masks
+ gt_masks = []
+ for prompt in prompts:
+     gt_mask = Image.open(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png", formats=['png'])
+     gt_mask = 1*(np.array(gt_mask.convert('RGB'))[:,:,0] > 0)
+     gt_masks.append(gt_mask)
+
+ # Compare predictions with the ground truth masks
+ for i, pred in enumerate(pred_mask):
+     gt = gt_masks[i]
+     dice = (1*(pred>0.5) & gt).sum() * 2.0 / (1*(pred>0.5).sum() + gt.sum())
+     print(f'Dice score for {prompts[i]}: {dice:.4f}')
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import matplotlib.patches as mpatches
+
+ def overlay_masks(image, masks, colors):
+     overlay = image.copy()
+     overlay = np.array(overlay, dtype=np.uint8)
+     for mask, color in zip(masks, colors):
+         overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(np.uint8)
+     return Image.fromarray(overlay)
+
+ def generate_colors(n):
+     cmap = plt.get_cmap('tab10')
+     colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
+     return colors
+
+ original_image = Image.open('examples/Part_1_516_pathology_breast.png').convert('RGB')
+
+ colors = generate_colors(len(prompts))
+
+ pred_overlay = overlay_masks(original_image, [1*(pred_mask[i] > 0.5) for i in range(len(prompts))], colors)
+
+ gt_overlay = overlay_masks(original_image, gt_masks, colors)
+
+ legend_patches = [mpatches.Patch(color=np.array(color) / 255, label=prompt) for color, prompt in zip(colors, prompts)]
+
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+ axes[0].imshow(original_image)
+ axes[0].set_title("Original Image")
+ axes[0].axis('off')
+
+ axes[1].imshow(pred_overlay)
+ axes[1].set_title("Predictions")
+ axes[1].axis('off')
+ axes[1].legend(handles=legend_patches, loc='upper right', fontsize='small')
+
+ axes[2].imshow(gt_overlay)
+ axes[2].set_title("Ground Truth")
+ axes[2].axis('off')
+ axes[2].legend(handles=legend_patches, loc='upper right', fontsize='small')
+
+ plt.tight_layout()
+ plt.show()
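As an aside, the Dice computation in the demo is written inline; an equivalent standalone helper (a sketch using only NumPy, with the same 0.5 threshold) makes the binarization step explicit:

import numpy as np

def dice_score(pred_prob, gt_mask, threshold=0.5):
    # Binarize the predicted probability map, then compare it with the ground-truth mask.
    pred = (pred_prob > threshold).astype(np.uint8)
    gt = (gt_mask > 0).astype(np.uint8)
    intersection = (pred & gt).sum()
    denom = pred.sum() + gt.sum()
    return 2.0 * intersection / denom if denom > 0 else 1.0

# Example: dice_score(pred_mask[i], gt_masks[i]) reproduces the per-prompt scores printed above.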
colabs/environment.yml ADDED
@@ -0,0 +1,149 @@
1
+ name: biomedparse
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - blas=1.0=mkl
10
+ - brotli-python=1.0.9=py39h6a678d5_8
11
+ - bzip2=1.0.8=h5eee18b_6
12
+ - ca-certificates=2024.7.2=h06a4308_0
13
+ - certifi=2024.7.4=py39h06a4308_0
14
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
15
+ - cuda-cudart=12.4.127=0
16
+ - cuda-cupti=12.4.127=0
17
+ - cuda-libraries=12.4.0=0
18
+ - cuda-nvrtc=12.4.127=0
19
+ - cuda-nvtx=12.4.127=0
20
+ - cuda-opencl=12.6.37=0
21
+ - cuda-runtime=12.4.0=0
22
+ - cuda-version=12.6=3
23
+ - ffmpeg=4.3=hf484d3e_0
24
+ - filelock=3.13.1=py39h06a4308_0
25
+ - freetype=2.12.1=h4a9f257_0
26
+ - gmp=6.2.1=h295c915_3
27
+ - gmpy2=2.1.2=py39heeb90bb_0
28
+ - gnutls=3.6.15=he1e5248_0
29
+ - idna=3.7=py39h06a4308_0
30
+ - intel-openmp=2023.1.0=hdb19cb5_46306
31
+ - jinja2=3.1.4=py39h06a4308_0
32
+ - jpeg=9e=h5eee18b_3
33
+ - lame=3.100=h7b6447c_0
34
+ - lcms2=2.12=h3be6417_0
35
+ - ld_impl_linux-64=2.38=h1181459_1
36
+ - lerc=3.0=h295c915_0
37
+ - libcublas=12.4.2.65=0
38
+ - libcufft=11.2.0.44=0
39
+ - libcufile=1.11.0.15=0
40
+ - libcurand=10.3.7.37=0
41
+ - libcusolver=11.6.0.99=0
42
+ - libcusparse=12.3.0.142=0
43
+ - libdeflate=1.17=h5eee18b_1
44
+ - libffi=3.4.4=h6a678d5_1
45
+ - libgcc-ng=11.2.0=h1234567_1
46
+ - libgomp=11.2.0=h1234567_1
47
+ - libiconv=1.16=h5eee18b_3
48
+ - libidn2=2.3.4=h5eee18b_0
49
+ - libjpeg-turbo=2.0.0=h9bf148f_0
50
+ - libnpp=12.2.5.2=0
51
+ - libnvfatbin=12.6.20=0
52
+ - libnvjitlink=12.4.99=0
53
+ - libnvjpeg=12.3.1.89=0
54
+ - libpng=1.6.39=h5eee18b_0
55
+ - libstdcxx-ng=11.2.0=h1234567_1
56
+ - libtasn1=4.19.0=h5eee18b_0
57
+ - libtiff=4.5.1=h6a678d5_0
58
+ - libunistring=0.9.10=h27cfd23_0
59
+ - libwebp-base=1.3.2=h5eee18b_0
60
+ - llvm-openmp=14.0.6=h9e868ea_0
61
+ - lz4-c=1.9.4=h6a678d5_1
62
+ - markupsafe=2.1.3=py39h5eee18b_0
63
+ - mkl=2023.1.0=h213fc3f_46344
64
+ - mkl-service=2.4.0=py39h5eee18b_1
65
+ - mkl_fft=1.3.8=py39h5eee18b_0
66
+ - mkl_random=1.2.4=py39hdb19cb5_0
67
+ - mpc=1.1.0=h10f8cd9_1
68
+ - mpfr=4.0.2=hb69a4c5_1
69
+ - mpmath=1.3.0=py39h06a4308_0
70
+ - ncurses=6.4=h6a678d5_0
71
+ - nettle=3.7.3=hbbd107a_1
72
+ - networkx=3.2.1=py39h06a4308_0
73
+ - openh264=2.1.1=h4ff587b_0
74
+ - openjpeg=2.5.2=he7f1fd0_0
75
+ - openssl=3.0.14=h5eee18b_0
76
+ - pip=24.2=py39h06a4308_0
77
+ - pysocks=1.7.1=py39h06a4308_0
78
+ - python=3.9.19=h955ad1f_1
79
+ - pytorch=2.4.0=py3.9_cuda12.4_cudnn9.1.0_0
80
+ - pytorch-cuda=12.4=hc786d27_6
81
+ - pytorch-mutex=1.0=cuda
82
+ - pyyaml=6.0.1=py39h5eee18b_0
83
+ - readline=8.2=h5eee18b_0
84
+ - requests=2.32.3=py39h06a4308_0
85
+ - setuptools=72.1.0=py39h06a4308_0
86
+ - sqlite=3.45.3=h5eee18b_0
87
+ - sympy=1.12=py39h06a4308_0
88
+ - tbb=2021.8.0=hdb19cb5_0
89
+ - tk=8.6.14=h39e8969_0
90
+ - torchaudio=2.4.0=py39_cu124
91
+ - torchtriton=3.0.0=py39
92
+ - torchvision=0.19.0=py39_cu124
93
+ - typing_extensions=4.11.0=py39h06a4308_0
94
+ - tzdata=2024a=h04d1e81_0
95
+ - urllib3=2.2.2=py39h06a4308_0
96
+ - wheel=0.43.0=py39h06a4308_0
97
+ - xz=5.4.6=h5eee18b_1
98
+ - yaml=0.2.5=h7b6447c_0
99
+ - zlib=1.2.13=h5eee18b_1
100
+ - zstd=1.5.5=hc292b87_2
101
+ - pip:
102
+ - accelerate==0.23.0
103
+ - antlr4-python3-runtime==4.9.3
104
+ - appdirs==1.4.4
105
+ - black==21.4b2
106
+ - open-clip-torch==2.26.1
107
+ - cloudpickle==3.0.0
108
+ - cython==3.0.2
109
+ - deepspeed==0.10.3
110
+ - git+https://github.com/MaureenZOU/detectron2-xyz.git
111
+ - diffdist==0.1
112
+ - einops==0.8.0
113
+ - ftfy==6.1.1
114
+ - fvcore==0.1.5.post20221221
115
+ - hjson==3.1.0
116
+ - huggingface-hub==0.17.3
117
+ - hydra-core==1.3.2
118
+ - imageio==2.35.1
119
+ - infinibatch==0.1.1
120
+ - iopath==0.1.9
121
+ - json-tricks==3.17.3
122
+ - kornia==0.7.0
123
+ - mpi4py==3.1.5
124
+ - mup==1.0.0
125
+ - mypy-extensions==1.0.0
126
+ - ninja==1.11.1.1
127
+ - nltk==3.8.1
128
+ - numpy==1.23.1
129
+ - omegaconf==2.3.0
130
+ - opencv-python==4.8.1.78
131
+ - pandas==2.0.3
132
+ - pathspec==0.12.1
133
+ - pillow==9.4.0
134
+ - portalocker==2.10.1
135
+ - py-cpuinfo==9.0.0
136
+ - pycocotools==2.0.7
137
+ - pydantic==1.10.18
138
+ - pydot==3.0.1
139
+ - regex==2023.10.3
140
+ - scikit-image==0.21.0
141
+ - scikit-learn==1.3.1
142
+ - sentencepiece==0.1.99
143
+ - tabulate==0.9.0
144
+ - termcolor==2.4.0
145
+ - timm==0.4.12
146
+ - tokenizers==0.14.1
147
+ - transformers==4.34.0
148
+ - vision-datasets==0.2.2
149
+ - yacs==0.1.8
colabs/requirements-colab-pip-freeze.txt ADDED
@@ -0,0 +1,567 @@
1
+ absl-py==1.4.0
2
+ accelerate==0.23.0
3
+ aiohappyeyeballs==2.4.4
4
+ aiohttp==3.11.10
5
+ aiosignal==1.3.2
6
+ alabaster==1.0.0
7
+ albucore==0.0.19
8
+ albumentations==1.4.20
9
+ altair==5.5.0
10
+ annotated-types==0.7.0
11
+ antlr4-python3-runtime==4.9.3
12
+ anyio==3.7.1
13
+ appdirs==1.4.4
14
+ argon2-cffi==23.1.0
15
+ argon2-cffi-bindings==21.2.0
16
+ array_record==0.5.1
17
+ arviz==0.20.0
18
+ astropy==6.1.7
19
+ astropy-iers-data==0.2024.12.16.0.35.48
20
+ astunparse==1.6.3
21
+ async-timeout==4.0.3
22
+ atpublic==4.1.0
23
+ attrs==24.3.0
24
+ audioread==3.0.1
25
+ autograd==1.7.0
26
+ babel==2.16.0
27
+ backcall==0.2.0
28
+ beautifulsoup4==4.12.3
29
+ bigframes==1.29.0
30
+ bigquery-magics==0.4.0
31
+ black==21.4b2
32
+ bleach==6.2.0
33
+ blinker==1.9.0
34
+ blis==0.7.11
35
+ blosc2==2.7.1
36
+ bokeh==3.6.2
37
+ Bottleneck==1.4.2
38
+ bqplot==0.12.43
39
+ branca==0.8.1
40
+ CacheControl==0.14.1
41
+ cachetools==5.5.0
42
+ catalogue==2.0.10
43
+ certifi==2024.12.14
44
+ cffi==1.17.1
45
+ chardet==5.2.0
46
+ charset-normalizer==3.4.0
47
+ chex==0.1.88
48
+ clarabel==0.9.0
49
+ click==8.1.7
50
+ cloudpathlib==0.20.0
51
+ cloudpickle==3.1.0
52
+ cmake==3.31.2
53
+ cmdstanpy==1.2.5
54
+ colorcet==3.1.0
55
+ colorlover==0.3.0
56
+ colour==0.1.5
57
+ community==1.0.0b1
58
+ confection==0.1.5
59
+ cons==0.4.6
60
+ contourpy==1.3.1
61
+ cryptography==43.0.3
62
+ cuda-python==12.2.1
63
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
64
+ cufflinks==0.17.3
65
+ cupy-cuda12x==12.2.0
66
+ cvxopt==1.3.2
67
+ cvxpy==1.6.0
68
+ cycler==0.12.1
69
+ cymem==2.0.10
70
+ Cython==3.0.2
71
+ dask==2024.10.0
72
+ datascience==0.17.6
73
+ db-dtypes==1.3.1
74
+ dbus-python==1.2.18
75
+ debugpy==1.8.0
76
+ decorator==4.4.2
77
+ deepspeed==0.10.3
78
+ defusedxml==0.7.1
79
+ Deprecated==1.2.15
80
+ detectron2 @ git+https://github.com/MaureenZOU/detectron2-xyz.git@42121d75e10d9f858f3a91b6a39f5722c02868f0
81
+ diffdist==0.1
82
+ diffusers==0.31.0
83
+ distro==1.9.0
84
+ dlib==19.24.2
85
+ dm-tree==0.1.8
86
+ docker-pycreds==0.4.0
87
+ docstring_parser==0.16
88
+ docutils==0.21.2
89
+ dopamine_rl==4.1.0
90
+ duckdb==1.1.3
91
+ earthengine-api==1.4.3
92
+ easydict==1.13
93
+ editdistance==0.8.1
94
+ eerepr==0.0.4
95
+ einops==0.8.0
96
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
97
+ entrypoints==0.4
98
+ et_xmlfile==2.0.0
99
+ etils==1.11.0
100
+ etuples==0.3.9
101
+ eval_type_backport==0.2.0
102
+ exceptiongroup==1.2.2
103
+ fastai==2.7.18
104
+ fastcore==1.7.27
105
+ fastdownload==0.0.7
106
+ fastjsonschema==2.21.1
107
+ fastprogress==1.0.3
108
+ fastrlock==0.8.3
109
+ filelock==3.16.1
110
+ firebase-admin==6.6.0
111
+ Flask==3.1.0
112
+ flatbuffers==24.3.25
113
+ flax==0.8.5
114
+ folium==0.19.2
115
+ fonttools==4.55.3
116
+ frozendict==2.4.6
117
+ frozenlist==1.5.0
118
+ fsspec==2024.10.0
119
+ ftfy==6.1.1
120
+ future==1.0.0
121
+ fvcore==0.1.5.post20221221
122
+ gast==0.6.0
123
+ gcsfs==2024.10.0
124
+ GDAL==3.6.4
125
+ gdown==5.2.0
126
+ geemap==0.35.1
127
+ gensim==4.3.3
128
+ geocoder==1.38.1
129
+ geographiclib==2.0
130
+ geopandas==1.0.1
131
+ geopy==2.4.1
132
+ gin-config==0.5.0
133
+ gitdb==4.0.11
134
+ GitPython==3.1.43
135
+ glob2==0.7
136
+ google==2.0.3
137
+ google-ai-generativelanguage==0.6.10
138
+ google-api-core==2.19.2
139
+ google-api-python-client==2.155.0
140
+ google-auth==2.27.0
141
+ google-auth-httplib2==0.2.0
142
+ google-auth-oauthlib==1.2.1
143
+ google-cloud-aiplatform==1.74.0
144
+ google-cloud-bigquery==3.25.0
145
+ google-cloud-bigquery-connection==1.17.0
146
+ google-cloud-bigquery-storage==2.27.0
147
+ google-cloud-bigtable==2.27.0
148
+ google-cloud-core==2.4.1
149
+ google-cloud-datastore==2.20.2
150
+ google-cloud-firestore==2.19.0
151
+ google-cloud-functions==1.19.0
152
+ google-cloud-iam==2.17.0
153
+ google-cloud-language==2.16.0
154
+ google-cloud-pubsub==2.27.1
155
+ google-cloud-resource-manager==1.14.0
156
+ google-cloud-storage==2.19.0
157
+ google-cloud-translate==3.19.0
158
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
159
+ google-crc32c==1.6.0
160
+ google-genai==0.3.0
161
+ google-generativeai==0.8.3
162
+ google-pasta==0.2.0
163
+ google-resumable-media==2.7.2
164
+ googleapis-common-protos==1.66.0
165
+ googledrivedownloader==0.4
166
+ graphviz==0.20.3
167
+ greenlet==3.1.1
168
+ grpc-google-iam-v1==0.13.1
169
+ grpcio==1.68.1
170
+ grpcio-status==1.62.3
171
+ gspread==6.0.2
172
+ gspread-dataframe==3.3.1
173
+ gym==0.25.2
174
+ gym-notices==0.0.8
175
+ h11==0.14.0
176
+ h5netcdf==1.4.1
177
+ h5py==3.12.1
178
+ hjson==3.1.0
179
+ holidays==0.63
180
+ holoviews==1.20.0
181
+ html5lib==1.1
182
+ httpcore==1.0.7
183
+ httpimport==1.4.0
184
+ httplib2==0.22.0
185
+ httpx==0.28.1
186
+ huggingface-hub==0.17.3
187
+ humanize==4.11.0
188
+ hydra-core==1.3.2
189
+ hyperopt==0.2.7
190
+ ibis-framework==9.2.0
191
+ idna==3.10
192
+ imageio==2.36.1
193
+ imageio-ffmpeg==0.5.1
194
+ imagesize==1.4.1
195
+ imbalanced-learn==0.12.4
196
+ imgaug==0.4.0
197
+ immutabledict==4.2.1
198
+ importlib_metadata==8.5.0
199
+ importlib_resources==6.4.5
200
+ imutils==0.5.4
201
+ infinibatch==0.1.1
202
+ inflect==7.4.0
203
+ iniconfig==2.0.0
204
+ intel-cmplr-lib-ur==2025.0.4
205
+ intel-openmp==2025.0.4
206
+ iopath==0.1.9
207
+ ipyevents==2.0.2
208
+ ipyfilechooser==0.6.0
209
+ ipykernel==5.5.6
210
+ ipyleaflet==0.19.2
211
+ ipyparallel==8.8.0
212
+ ipython==7.34.0
213
+ ipython-genutils==0.2.0
214
+ ipython-sql==0.5.0
215
+ ipytree==0.2.2
216
+ ipywidgets==7.7.1
217
+ itsdangerous==2.2.0
218
+ jax==0.4.33
219
+ jax-cuda12-pjrt==0.4.33
220
+ jax-cuda12-plugin==0.4.33
221
+ jaxlib==0.4.33
222
+ jeepney==0.7.1
223
+ jellyfish==1.1.0
224
+ jieba==0.42.1
225
+ Jinja2==3.1.4
226
+ jiter==0.8.2
227
+ joblib==1.4.2
228
+ json-tricks==3.17.3
229
+ jsonpatch==1.33
230
+ jsonpickle==4.0.1
231
+ jsonpointer==3.0.0
232
+ jsonschema==4.23.0
233
+ jsonschema-specifications==2024.10.1
234
+ jupyter-client==6.1.12
235
+ jupyter-console==6.1.0
236
+ jupyter-leaflet==0.19.2
237
+ jupyter-server==1.24.0
238
+ jupyter_core==5.7.2
239
+ jupyterlab_pygments==0.3.0
240
+ jupyterlab_widgets==3.0.13
241
+ kaggle==1.6.17
242
+ kagglehub==0.3.5
243
+ keras==3.5.0
244
+ keyring==23.5.0
245
+ kiwisolver==1.4.7
246
+ kornia==0.7.0
247
+ langchain==0.3.12
248
+ langchain-core==0.3.25
249
+ langchain-text-splitters==0.3.3
250
+ langcodes==3.5.0
251
+ langsmith==0.2.3
252
+ language_data==1.3.0
253
+ launchpadlib==1.10.16
254
+ lazr.restfulclient==0.14.4
255
+ lazr.uri==1.0.6
256
+ lazy_loader==0.4
257
+ libclang==18.1.1
258
+ libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-24.10.1-py3-none-manylinux_2_28_x86_64.whl
259
+ librosa==0.10.2.post1
260
+ lightgbm==4.5.0
261
+ linkify-it-py==2.0.3
262
+ llvmlite==0.43.0
263
+ locket==1.0.0
264
+ logical-unification==0.4.6
265
+ lxml==5.3.0
266
+ marisa-trie==1.2.1
267
+ Markdown==3.7
268
+ markdown-it-py==3.0.0
269
+ MarkupSafe==3.0.2
270
+ matplotlib==3.8.0
271
+ matplotlib-inline==0.1.7
272
+ matplotlib-venn==1.1.1
273
+ mdit-py-plugins==0.4.2
274
+ mdurl==0.1.2
275
+ miniKanren==1.0.3
276
+ missingno==0.5.2
277
+ mistune==3.0.2
278
+ mizani==0.13.1
279
+ mkl==2025.0.1
280
+ ml-dtypes==0.4.1
281
+ mlxtend==0.23.3
282
+ more-itertools==10.5.0
283
+ moviepy==1.0.3
284
+ mpi4py==3.1.5
285
+ mpmath==1.3.0
286
+ msgpack==1.1.0
287
+ multidict==6.1.0
288
+ multipledispatch==1.0.0
289
+ multitasking==0.0.11
290
+ mup==1.0.0
291
+ murmurhash==1.0.11
292
+ music21==9.3.0
293
+ mypy-extensions==1.0.0
294
+ namex==0.0.8
295
+ narwhals==1.18.4
296
+ natsort==8.4.0
297
+ nbclassic==1.1.0
298
+ nbclient==0.10.1
299
+ nbconvert==7.16.4
300
+ nbformat==5.10.4
301
+ ndindex==1.9.2
302
+ nest-asyncio==1.6.0
303
+ networkx==3.4.2
304
+ nibabel==5.3.2
305
+ ninja==1.11.1.3
306
+ nltk==3.8.1
307
+ notebook==6.5.5
308
+ notebook_shim==0.2.4
309
+ numba==0.60.0
310
+ numexpr==2.10.2
311
+ numpy==1.26.4
312
+ nvidia-cublas-cu12==12.6.4.1
313
+ nvidia-cuda-cupti-cu12==12.6.80
314
+ nvidia-cuda-nvcc-cu12==12.6.85
315
+ nvidia-cuda-runtime-cu12==12.6.77
316
+ nvidia-cudnn-cu12==9.6.0.74
317
+ nvidia-cufft-cu12==11.3.0.4
318
+ nvidia-curand-cu12==10.3.7.77
319
+ nvidia-cusolver-cu12==11.7.1.2
320
+ nvidia-cusparse-cu12==12.5.4.2
321
+ nvidia-nccl-cu12==2.23.4
322
+ nvidia-nvjitlink-cu12==12.6.85
323
+ nvtx==0.2.10
324
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-24.10.0-py3-none-any.whl
325
+ oauth2client==4.1.3
326
+ oauthlib==3.2.2
327
+ omegaconf==2.3.0
328
+ open_clip_torch==2.26.1
329
+ openai==1.57.4
330
+ opencv-contrib-python==4.10.0.84
331
+ opencv-python==4.8.1.78
332
+ opencv-python-headless==4.10.0.84
333
+ openpyxl==3.1.5
334
+ opentelemetry-api==1.29.0
335
+ opentelemetry-sdk==1.29.0
336
+ opentelemetry-semantic-conventions==0.50b0
337
+ opt_einsum==3.4.0
338
+ optax==0.2.4
339
+ optree==0.13.1
340
+ orbax-checkpoint==0.6.4
341
+ orjson==3.10.12
342
+ osqp==0.6.7.post3
343
+ packaging==24.2
344
+ pandas==2.0.3
345
+ pandas-datareader==0.10.0
346
+ pandas-gbq==0.25.0
347
+ pandas-stubs==2.2.2.240909
348
+ pandocfilters==1.5.1
349
+ panel==1.5.4
350
+ param==2.2.0
351
+ parso==0.8.4
352
+ parsy==2.1
353
+ partd==1.4.2
354
+ pathlib==1.0.1
355
+ pathspec==0.12.1
356
+ patsy==1.0.1
357
+ peewee==3.17.8
358
+ peft==0.14.0
359
+ pexpect==4.9.0
360
+ pickleshare==0.7.5
361
+ Pillow==9.4.0
362
+ platformdirs==4.3.6
363
+ plotly==5.24.1
364
+ plotnine==0.14.4
365
+ pluggy==1.5.0
366
+ ply==3.11
367
+ polars==1.9.0
368
+ pooch==1.8.2
369
+ portalocker==3.0.0
370
+ portpicker==1.5.2
371
+ preshed==3.0.9
372
+ prettytable==3.12.0
373
+ proglog==0.1.10
374
+ progressbar2==4.5.0
375
+ prometheus_client==0.21.1
376
+ promise==2.3
377
+ prompt_toolkit==3.0.48
378
+ propcache==0.2.1
379
+ prophet==1.1.6
380
+ proto-plus==1.25.0
381
+ protobuf==4.25.5
382
+ psutil==5.9.5
383
+ psycopg2==2.9.10
384
+ ptyprocess==0.7.0
385
+ py-cpuinfo==9.0.0
386
+ py4j==0.10.9.7
387
+ pyarrow==17.0.0
388
+ pyasn1==0.6.1
389
+ pyasn1_modules==0.4.1
390
+ pycocotools==2.0.7
391
+ pycparser==2.22
392
+ pydantic==1.10.19
393
+ pydantic_core==2.27.1
394
+ pydata-google-auth==1.9.0
395
+ pydot==3.0.3
396
+ pydotplus==2.0.2
397
+ PyDrive==1.3.1
398
+ PyDrive2==1.21.3
399
+ pyerfa==2.0.1.5
400
+ pygame==2.6.1
401
+ pygit2==1.16.0
402
+ Pygments==2.18.0
403
+ PyGObject==3.42.1
404
+ PyJWT==2.10.1
405
+ pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-24.10.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
406
+ pylibcugraph-cu12==24.10.0
407
+ pylibraft-cu12==24.10.0
408
+ pymc==5.19.1
409
+ pymystem3==0.2.0
410
+ pynvjitlink-cu12==0.4.0
411
+ pyogrio==0.10.0
412
+ Pyomo==6.8.2
413
+ PyOpenGL==3.1.7
414
+ pyOpenSSL==24.2.1
415
+ pyparsing==3.2.0
416
+ pyperclip==1.9.0
417
+ pyproj==3.7.0
418
+ pyshp==2.3.1
419
+ PySocks==1.7.1
420
+ pyspark==3.5.3
421
+ pytensor==2.26.4
422
+ pytest==8.3.4
423
+ python-apt==0.0.0
424
+ python-box==7.3.0
425
+ python-dateutil==2.8.2
426
+ python-louvain==0.16
427
+ python-slugify==8.0.4
428
+ python-utils==3.9.1
429
+ pytz==2024.2
430
+ pyviz_comms==3.0.3
431
+ PyWavelets==1.8.0
432
+ PyYAML==6.0.1
433
+ pyzmq==24.0.1
434
+ qdldl==0.1.7.post4
435
+ ratelim==0.1.6
436
+ referencing==0.35.1
437
+ regex==2023.10.3
438
+ requests==2.32.3
439
+ requests-oauthlib==1.3.1
440
+ requests-toolbelt==1.0.0
441
+ requirements-parser==0.9.0
442
+ rich==13.9.4
443
+ rmm-cu12==24.10.0
444
+ rpds-py==0.22.3
445
+ rpy2==3.4.2
446
+ rsa==4.9
447
+ safetensors==0.4.5
448
+ scikit-image==0.21.0
449
+ scikit-learn==1.3.1
450
+ scipy==1.13.1
451
+ scooby==0.10.0
452
+ scs==3.2.7
453
+ seaborn==0.13.2
454
+ SecretStorage==3.3.1
455
+ Send2Trash==1.8.3
456
+ sentence-transformers==3.3.1
457
+ sentencepiece==0.1.99
458
+ sentry-sdk==2.19.2
459
+ setproctitle==1.3.4
460
+ shap==0.46.0
461
+ shapely==2.0.6
462
+ shellingham==1.5.4
463
+ simple-parsing==0.1.6
464
+ six==1.17.0
465
+ sklearn-pandas==2.2.0
466
+ slicer==0.0.8
467
+ smart-open==7.1.0
468
+ smmap==5.0.1
469
+ sniffio==1.3.1
470
+ snowballstemmer==2.2.0
471
+ soundfile==0.12.1
472
+ soupsieve==2.6
473
+ soxr==0.5.0.post1
474
+ spacy==3.7.5
475
+ spacy-legacy==3.0.12
476
+ spacy-loggers==1.0.5
477
+ Sphinx==8.1.3
478
+ sphinxcontrib-applehelp==2.0.0
479
+ sphinxcontrib-devhelp==2.0.0
480
+ sphinxcontrib-htmlhelp==2.1.0
481
+ sphinxcontrib-jsmath==1.0.1
482
+ sphinxcontrib-qthelp==2.0.0
483
+ sphinxcontrib-serializinghtml==2.0.0
484
+ SQLAlchemy==2.0.36
485
+ sqlglot==25.1.0
486
+ sqlparse==0.5.3
487
+ srsly==2.5.0
488
+ stanio==0.5.1
489
+ statsmodels==0.14.4
490
+ StrEnum==0.4.15
491
+ stringzilla==3.11.1
492
+ sympy==1.13.1
493
+ tables==3.10.1
494
+ tabulate==0.9.0
495
+ tbb==2022.0.0
496
+ tcmlib==1.2.0
497
+ tenacity==9.0.0
498
+ tensorboard==2.17.1
499
+ tensorboard-data-server==0.7.2
500
+ tensorflow==2.17.1
501
+ tensorflow-datasets==4.9.7
502
+ tensorflow-hub==0.16.1
503
+ tensorflow-io-gcs-filesystem==0.37.1
504
+ tensorflow-metadata==1.13.1
505
+ tensorflow-probability==0.24.0
506
+ tensorstore==0.1.71
507
+ termcolor==2.5.0
508
+ terminado==0.18.1
509
+ text-unidecode==1.3
510
+ textblob==0.17.1
511
+ tf-slim==1.1.0
512
+ tf_keras==2.17.0
513
+ thinc==8.2.5
514
+ threadpoolctl==3.5.0
515
+ tifffile==2024.12.12
516
+ timm==0.4.12
517
+ tinycss2==1.4.0
518
+ tokenizers==0.14.1
519
+ toml==0.10.2
520
+ tomli==2.2.1
521
+ toolz==0.12.1
522
+ torch @ https://download.pytorch.org/whl/cu121_full/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
523
+ torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl
524
+ torchsummary==1.5.1
525
+ torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp310-cp310-linux_x86_64.whl
526
+ tornado==6.3.3
527
+ tqdm==4.67.1
528
+ traitlets==5.7.1
529
+ traittypes==0.2.1
530
+ transformers==4.34.0
531
+ tweepy==4.14.0
532
+ typeguard==4.4.1
533
+ typer==0.15.1
534
+ types-pytz==2024.2.0.20241003
535
+ types-setuptools==75.6.0.20241126
536
+ typing_extensions==4.12.2
537
+ tzdata==2024.2
538
+ tzlocal==5.2
539
+ uc-micro-py==1.0.3
540
+ umf==0.9.1
541
+ uritemplate==4.1.1
542
+ urllib3==2.2.3
543
+ vega-datasets==0.9.0
544
+ vision-datasets==0.2.2
545
+ wadllib==1.3.6
546
+ wandb==0.19.1
547
+ wasabi==1.1.3
548
+ wcwidth==0.2.13
549
+ weasel==0.4.1
550
+ webcolors==24.11.1
551
+ webencodings==0.5.1
552
+ websocket-client==1.8.0
553
+ websockets==14.1
554
+ Werkzeug==3.1.3
555
+ widgetsnbextension==3.6.10
556
+ wordcloud==1.9.4
557
+ wrapt==1.17.0
558
+ xarray==2024.11.0
559
+ xarray-einstats==0.8.0
560
+ xgboost==2.1.3
561
+ xlrd==2.0.1
562
+ xyzservices==2024.9.0
563
+ yacs==0.1.8
564
+ yarl==1.18.3
565
+ yellowbrick==1.5
566
+ yfinance==0.2.50
567
+ zipp==3.21.0
colabs/requirements-colab.txt ADDED
@@ -0,0 +1,39 @@
+ pillow==9.4.0
+ opencv-python==4.8.1.78
+ pyyaml==6.0.1
+ json_tricks==3.17.3
+ yacs==0.1.8
+ scikit-learn==1.3.1
+ pandas==2.0.3
+ timm==0.4.12
+ numpy==1.26.4
+ einops==0.8.0
+ fvcore==0.1.5.post20221221
+ transformers==4.34.0
+ sentencepiece==0.1.99
+ ftfy==6.1.1
+ regex==2023.10.3
+ nltk==3.8.1
+ mpi4py==3.1.5
+ vision-datasets==0.2.2
+ cython==3.0.2
+ pycocotools==2.0.7
+ diffdist==0.1
+ #pyarrow==13.0.0
+ #cityscapesscripts==2.2.2
+ #shapely==1.8.0
+ scikit-image==0.21.0
+ mup==1.0.0
+ accelerate==0.23.0
+ kornia==0.7.0
+ deepspeed==0.10.3
+ #wandb==0.15.12
+ infinibatch==0.1.1
+ open-clip-torch==2.26.1
+ git+https://github.com/MaureenZOU/detectron2-xyz.git
+ #gradio==3.42.0
+ #torch==2.3.1 #2.0.1
+ #torchvision==0.15.2
+ #torchaudio==2.0.2
+ #torch==2.1.0
+ #torchvision==0.16.0
configs/biomedparse_inference.yaml ADDED
@@ -0,0 +1,204 @@
1
+ # Define Test/Trainer/Saving
2
+ PIPELINE: XDecoderPipeline
3
+ TRAINER: xdecoder
4
+ SAVE_DIR: "../../data/output/test"
5
+ base_path: "./"
6
+
7
+ # Resume Logistic
8
+ RESUME: false
9
+ WEIGHT: false
10
+ RESUME_FROM: ""
11
+ EVAL_AT_START: false
12
+
13
+ # Logging and Debug
14
+ WANDB: False
15
+ LOG_EVERY: 100
16
+ FIND_UNUSED_PARAMETERS: false
17
+
18
+ # Speed up training
19
+ FP16: false
20
+ PORT: "36873"
21
+
22
+ # misc
23
+ LOADER:
24
+ JOINT: False
25
+ KEY_DATASET: "coco"
26
+
27
+ STANDARD_TEXT_FOR_EVAL: False
28
+
29
+ ##################
30
+ # Task settings
31
+ ##################
32
+ VERBOSE: true
33
+ MODEL:
34
+ NAME: seem_model_demo
35
+ HEAD: xdecoder_head
36
+ DIM_PROJ: 512
37
+ TEXT:
38
+ ARCH: vlpencoder
39
+ NAME: transformer
40
+ TOKENIZER: clip
41
+ CONTEXT_LENGTH: 77 # 77
42
+ WIDTH: 512
43
+ HEADS: 8
44
+ LAYERS: 12 # 6
45
+ AUTOGRESSIVE: True
46
+ BACKBONE:
47
+ NAME: focal
48
+ PRETRAINED: ""
49
+ LOAD_PRETRAINED: false
50
+ FOCAL:
51
+ PRETRAIN_IMG_SIZE: 224
52
+ PATCH_SIZE: 4
53
+ EMBED_DIM: 192
54
+ DEPTHS: [2, 2, 18, 2]
55
+ FOCAL_LEVELS: [4, 4, 4, 4]
56
+ FOCAL_WINDOWS: [3, 3, 3, 3]
57
+ DROP_PATH_RATE: 0.3
58
+ MLP_RATIO: 4.0
59
+ DROP_RATE: 0.0
60
+ PATCH_NORM: True
61
+ USE_CONV_EMBED: True
62
+ SCALING_MODULATOR: True
63
+ USE_CHECKPOINT: False
64
+ USE_POSTLN: true
65
+ USE_POSTLN_IN_MODULATION: false
66
+ USE_LAYERSCALE: True
67
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
68
+ OUT_INDICES: [0, 1, 2, 3]
69
+ ENCODER:
70
+ NAME: transformer_encoder_fpn
71
+ IGNORE_VALUE: 255
72
+ NUM_CLASSES: 16
73
+ BINARY_CLASSES: False
74
+ LOSS_WEIGHT: 1.0
75
+ CONVS_DIM: 512
76
+ MASK_DIM: 512
77
+ NORM: "GN"
78
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
79
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
80
+ COMMON_STRIDE: 4
81
+ TRANSFORMER_ENC_LAYERS: 6
82
+ DECODER:
83
+ NAME: seem_demo
84
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
85
+ MASK:
86
+ ENABLED: False
87
+ DETECTION: False
88
+ SPATIAL:
89
+ ENABLED: True
90
+ MAX_ITER: 1
91
+ GROUNDING:
92
+ ENABLED: True
93
+ MAX_LEN: 5
94
+ TEXT_WEIGHT: 2.0
95
+ CLASS_WEIGHT: 0.5
96
+ VISUAL:
97
+ ENABLED: False
98
+ AUDIO:
99
+ ENABLED: False
100
+ RETRIEVAL:
101
+ ENABLED: False
102
+ LVIS:
103
+ ENABLED: True
104
+ THRES: 0.7
105
+ OPENIMAGE:
106
+ ENABLED: False
107
+ NEGATIVE_SAMPLES: 5
108
+ GROUNDING:
109
+ ENABLED: False
110
+ MAX_LEN: 5
111
+ CAPTION:
112
+ ENABLED: False
113
+ PHRASE_PROB: 0.5
114
+ SIM_THRES: 0.95
115
+ DEEP_SUPERVISION: True
116
+ NO_OBJECT_WEIGHT: 0.1
117
+ GCLASS_WEIGHT: 0.4
118
+ GMASK_WEIGHT: 1.0
119
+ GDICE_WEIGHT: 1.0
120
+ SCLASS_WEIGHT: 0.4
121
+ SMASK_WEIGHT: 1.0
122
+ SDICE_WEIGHT: 1.0
123
+ OCLASS_WEIGHT: 0.4
124
+ OMASK_WEIGHT: 1.0
125
+ ODICE_WEIGHT: 1.0
126
+ CLASS_WEIGHT: 2.0
127
+ MASK_WEIGHT: 5.0
128
+ DICE_WEIGHT: 5.0
129
+ BBOX_WEIGHT: 5.0
130
+ GIOU_WEIGHT: 2.0
131
+ CAPTION_WEIGHT: 2.0
132
+ COST_SPATIAL:
133
+ CLASS_WEIGHT: 5.0
134
+ MASK_WEIGHT: 2.0
135
+ DICE_WEIGHT: 2.0
136
+ HIDDEN_DIM: 512
137
+ NUM_OBJECT_QUERIES: 101
138
+ NHEADS: 8
139
+ DROPOUT: 0.0
140
+ DIM_FEEDFORWARD: 2048
141
+ MAX_SPATIAL_LEN: [512, 512, 512, 512]
142
+ # ENC_LAYERS: 0
143
+ PRE_NORM: False
144
+ ENFORCE_INPUT_PROJ: False
145
+ SIZE_DIVISIBILITY: 32
146
+ TRAIN_NUM_POINTS: 12544
147
+ OVERSAMPLE_RATIO: 3.0
148
+ IMPORTANCE_SAMPLE_RATIO: 0.75
149
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
150
+ TOP_GROUNDING_LAYERS: 10
151
+ TOP_CAPTION_LAYERS: 10
152
+ TOP_SPATIAL_LAYERS: 10
153
+ TOP_OPENIMAGE_LAYERS: 10
154
+ TEST:
155
+ SEMANTIC_ON: True
156
+ INSTANCE_ON: True
157
+ PANOPTIC_ON: True
158
+ OVERLAP_THRESHOLD: 0.8
159
+ OBJECT_MASK_THRESHOLD: 0.4
160
+ SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
161
+ DETECTIONS_PER_IMAGE: 100
162
+
163
+ # Multi-modal Architecture, order matters
164
+ ATTENTION_ARCH:
165
+ VARIABLE:
166
+ queries: ["object"]
167
+ tokens: ["grounding", "spatial", "visual", "audio"]
168
+ SELF_ATTENTION:
169
+ queries:
170
+ object:
171
+ [
172
+ "queries_object",
173
+ "tokens_grounding",
174
+ "tokens_spatial",
175
+ "tokens_visual",
176
+ "tokens_audio",
177
+ ]
178
+ tokens:
179
+ grounding: ["queries_object", "tokens_grounding"]
180
+ spatial: ["tokens_spatial"]
181
+ visual: ["tokens_visual"]
182
+ audio: ["queries_object", "tokens_audio"]
183
+ CROSS_ATTENTION:
184
+ queries:
185
+ object: True
186
+ tokens:
187
+ grounding: False
188
+ spatial: False
189
+ visual: False
190
+ audio: False
191
+ MASKING:
192
+ ["tokens_spatial", "tokens_grounding", "tokens_visual", "tokens_audio"]
193
+ DUPLICATION:
194
+ queries:
195
+ grounding: "queries_object"
196
+ spatial: "queries_object"
197
+ SPATIAL_MEMORIES: 32
198
+
199
+ INPUT:
200
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
201
+ PIXEL_STD: [58.395, 57.120, 57.375]
202
+ # INPUT:
203
+ # PIXEL_MEAN: [64.284, 59.293, 59.962]
204
+ # PIXEL_STD: [62.484, 60.865, 59.835]
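The notebook loads this file through load_opt_from_config_files; for a quick look outside the repo, plain PyYAML works as well. A small sketch (keys follow the config above; the path assumes the repo root as working directory):

import yaml

with open("configs/biomedparse_inference.yaml") as f:
    cfg = yaml.safe_load(f)

# A few of the settings defined above.
print(cfg["PIPELINE"])        # XDecoderPipeline
print(cfg["SAVE_DIR"])        # ../../data/output/test
print(cfg["MODEL"]["NAME"])   # seem_model_demo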
entrypoint.sh ADDED
@@ -0,0 +1,5 @@
+ #!/bin/bash
+ if [ -f "/run/secrets/HF_TOKEN" ]; then
+     export HF_TOKEN=$(cat /run/secrets/HF_TOKEN)
+ fi
+ exec conda run --no-capture-output -n biomedparse python main.py
examples/Part_1_516_pathology_breast.png ADDED

Git LFS Details

  • SHA256: 473e76cd22df5b7d9da17ed49dc7139be1f6d62d4854c49236bd953b35b04c34
  • Pointer size: 131 Bytes
  • Size of remote file: 966 kB
inference_utils/inference.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision import transforms
+ #from utils.visualizer import Visualizer
+ # from detectron2.utils.colormap import random_color
+ # from detectron2.data import MetadataCatalog
+ # from detectron2.structures import BitMasks
+ from modeling.language.loss import vl_similarity
+ from utilities.constants import BIOMED_CLASSES
+ #from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+
+ # import cv2
+ # import os
+ # import glob
+ # import subprocess
+ from PIL import Image
+ import random
+
+ t = []
+ t.append(transforms.Resize((1024, 1024), interpolation=Image.BICUBIC))
+ transform = transforms.Compose(t)
+ #metadata = MetadataCatalog.get('coco_2017_train_panoptic')
+ all_classes = ['background'] + [name.replace('-other','').replace('-merged','')
+                                 for name in BIOMED_CLASSES] + ["others"]
+ # colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]]
+
+ # use color list from matplotlib
+ import matplotlib.colors as mcolors
+ colors = dict(mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS)
+ colors_list = [list(colors.values())[i] for i in range(16)]
+
+ from .output_processing import mask_stats, combine_masks
+
+
+ @torch.no_grad()
+ def interactive_infer_image(model, image, prompts):
+
+     image_resize = transform(image)
+     width = image.size[0]
+     height = image.size[1]
+     image_resize = np.asarray(image_resize)
+     image = torch.from_numpy(image_resize.copy()).permute(2,0,1).cuda()
+
+     data = {"image": image, 'text': prompts, "height": height, "width": width}
+
+     # initialize task switches
+     model.model.task_switch['spatial'] = False
+     model.model.task_switch['visual'] = False
+     model.model.task_switch['grounding'] = True
+     model.model.task_switch['audio'] = False
+     model.model.task_switch['grounding'] = True
+
+
+     batch_inputs = [data]
+     results,image_size,extra = model.model.evaluate_demo(batch_inputs)
+
+     pred_masks = results['pred_masks'][0]
+     v_emb = results['pred_captions'][0]
+     t_emb = extra['grounding_class']
+
+     t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+     v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+     temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+     out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+     matched_id = out_prob.max(0)[1]
+     pred_masks_pos = pred_masks[matched_id,:,:]
+     pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+     # interpolate mask to original size
+     pred_mask_prob = F.interpolate(pred_masks_pos[None,], (data['height'], data['width']),
+                                    mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
+     pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)
+
+     return pred_mask_prob
+
+
+
+ # def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None):
+ #     image_ori = transform(image)
+ #     #mask_ori = image['mask']
+ #     width = image_ori.size[0]
+ #     height = image_ori.size[1]
+ #     image_ori = np.asarray(image_ori)
+ #     visual = Visualizer(image_ori, metadata=metadata)
+ #     images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
+
+ #     data = {"image": images, "height": height, "width": width}
+ #     if len(tasks) == 0:
+ #         tasks = ["Panoptic"]
+
+ #     # initialize task switches
+ #     model.model.task_switch['spatial'] = False
+ #     model.model.task_switch['visual'] = False
+ #     model.model.task_switch['grounding'] = False
+ #     model.model.task_switch['audio'] = False
+
+ #     # check if reftxt is a list of strings
+ #     assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}"
+ #     model.model.task_switch['grounding'] = True
+ #     predicts = {}
+ #     for i, txt in enumerate(reftxt):
+ #         data['text'] = txt
+ #         batch_inputs = [data]
+
+ #         results,image_size,extra = model.model.evaluate_demo(batch_inputs)
+
+ #         pred_masks = results['pred_masks'][0]
+ #         v_emb = results['pred_captions'][0]
+ #         t_emb = extra['grounding_class']
+
+ #         t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+ #         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+ #         temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
+ #         out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
+
+ #         matched_id = out_prob.max(0)[1]
+ #         pred_masks_pos = pred_masks[matched_id,:,:]
+ #         pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
+
+
+ #         # interpolate mask to original size
+ #         #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy()
+ #         # masks.append(pred_masks_pos[0])
+ #         # mask = pred_masks_pos[0]
+ #         # masks.append(mask)
+ #         # interpolate mask to original size
+ #         pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
+ #         #pred_masks_pos = 1*(pred_mask_prob > 0.5)
+ #         predicts[txt] = pred_mask_prob[0]
+
+ #     masks = combine_masks(predicts)
+
+ #     predict_mask_stats = {}
+ #     print(masks.keys())
+ #     for i, txt in enumerate(masks):
+ #         mask = masks[txt]
+ #         demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt)
+ #         predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori)
+
+ #     res = demo.get_image()
+ #     torch.cuda.empty_cache()
+ #     # return Image.fromarray(res), stroke_inimg, stroke_refimg
+ #     return Image.fromarray(res), None, predict_mask_stats
+
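The heart of interactive_infer_image is the text-to-mask matching: every predicted mask embedding is scored against each text prompt embedding, and the best-scoring proposal per prompt is kept. A minimal standalone sketch of that step (plain PyTorch with toy tensors standing in for v_emb and t_emb; the temperature scaling applied by vl_similarity is omitted):

import torch

# Toy stand-ins: 101 mask-proposal embeddings and 2 prompt embeddings, both 512-d.
v_emb = torch.randn(101, 512)   # mask/caption embeddings (results['pred_captions'])
t_emb = torch.randn(2, 512)     # text prompt embeddings (extra['grounding_class'])

# L2-normalize, exactly as in interactive_infer_image.
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)

# Cosine-similarity logits between every proposal and every prompt,
# then pick the best proposal per prompt (what vl_similarity + max(0) do above).
logits = v_emb @ t_emb.t()      # shape (101, 2)
matched_id = logits.max(0)[1]   # one proposal index per prompt
print(matched_id.shape)         # torch.Size([2])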
inference_utils/output_processing.py ADDED
@@ -0,0 +1,91 @@
+ import json
+ from scipy import stats
+ import numpy as np
+
+ import huggingface_hub
+
+
+ def check_mask_stats(img, mask, modality_type, target):
+     # img: np.array, shape=(H, W, 3), RGB image with pixel values in [0, 255]
+     # mask: np.array, shape=(H, W, 1), mask probability in [0, 1] scaled to pixel values in [0, 255]
+     # modality_type: str, see target_dist.json for the list of modality types
+     # target: str, see target_dist.json for the list of targets
+
+     huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename='target_dist.json', local_dir='./inference_utils')
+     huggingface_hub.hf_hub_download('microsoft/BiomedParse', filename="config.yaml", local_dir="./configs")
+     target_dist = json.load(open("inference_utils/target_dist.json"))
+
+     if modality_type not in target_dist:
+         raise ValueError(f"Currently supported modality types: {list(target_dist.keys())}")
+
+     if target not in target_dist[modality_type]:
+         raise ValueError(f"Currently supported targets for {modality_type}: {list(target_dist[modality_type].keys())}")
+
+     ms = mask_stats(mask, img)
+
+     ps = [stats.ks_1samp([ms[i]], stats.beta(param[0], param[1]).cdf).pvalue for i, param in enumerate(target_dist[modality_type][target])]
+     p_value = np.prod(ps)
+
+     adj_p_value = p_value**0.24  # adjustment for the product of four tests
+
+     return adj_p_value
+
+
+
+ def mask_stats(mask, img):
+     # mask is a prediction mask with pixel values in [0, 255] representing probabilities in [0, 1]
+     # img is an RGB image with pixel values in [0, 255]
+     if mask.max() <= 127:
+         return [0, 0, 0, 0]
+     return [mask[mask>=128].mean()/256, img[:,:,0][mask>=128].mean()/256,
+             img[:,:,1][mask>=128].mean()/256, img[:,:,2][mask>=128].mean()/256]
+
+
+
+ def combine_masks(predicts):
+     # predicts: a dictionary of pixel probabilities, {TARGET: pred_prob}
+     pixel_preds = {}
+     target_area = {}
+     target_probs = {}
+     for target in predicts:
+         pred = predicts[target]
+         pred_region = np.where(pred > 0.1)
+         target_area[target] = 0
+         target_probs[target] = 0
+         for (i,j) in zip(*pred_region):
+             if (i,j) not in pixel_preds:
+                 pixel_preds[(i,j)] = {}
+             pixel_preds[(i,j)][target] = pred[i,j]
+             target_area[target] += 1
+             target_probs[target] += pred[i,j]
+     for target in predicts:
+         if target_area[target] == 0:
+             continue
+         target_probs[target] /= target_area[target]
+
+     # generate combined masks
+     combined_areas = {t: 0 for t in predicts}
+     for index in pixel_preds:
+         pred_target = sorted(pixel_preds[index].keys(), key=lambda t: pixel_preds[index][t], reverse=True)[0]
+         combined_areas[pred_target] += 1
+
+     # discard targets with small areas
+     discard_targets = []
+     for target in predicts:
+         if combined_areas[target] < 0.6 * target_area[target]:
+             discard_targets.append(target)
+
+     # keep the most confident target
+     most_confident_target = sorted(predicts.keys(), key=lambda t: target_probs[t], reverse=True)[0]
+
+     discard_targets = [t for t in discard_targets if t != most_confident_target]
+
+     masks = {t: np.zeros_like(predicts[t]).astype(np.uint8) for t in predicts if t not in discard_targets}
+     for index in pixel_preds:
+         candidates = [t for t in pixel_preds[index] if t not in discard_targets and pixel_preds[index][t] > 0.5]
+         if len(candidates) == 0:
+             continue
+         pred_target = max(candidates, key=lambda t: pixel_preds[index][t])
+         masks[pred_target][index[0], index[1]] = 1
+
+     return masks
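For illustration, a hedged usage sketch of combine_masks with two synthetic, overlapping probability maps (the shapes and prompt names are made up; the import path assumes the repo root as working directory):

import numpy as np
from inference_utils.output_processing import combine_masks

# Two toy probability maps for overlapping prompts on a 64x64 image.
prob_a = np.zeros((64, 64)); prob_a[10:40, 10:40] = 0.9
prob_b = np.zeros((64, 64)); prob_b[30:60, 30:60] = 0.7

# Overlapping pixels go to the higher-probability target; targets that lose
# more than 40% of their area in this competition are discarded.
masks = combine_masks({"neoplastic cells": prob_a, "inflammatory cells": prob_b})
for name, m in masks.items():
    print(name, int(m.sum()))   # pixel count kept for each surviving target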
inference_utils/processing_utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from skimage import transform
3
+ import pydicom
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import nibabel as nib
7
+ import SimpleITK as sitk
8
+ from skimage import measure
9
+
10
+
11
+ """
12
+ This script contains utility functions for reading and processing different imaging modalities.
13
+ """
14
+
15
+
16
+ CT_WINDOWS = {'abdomen': [-150, 250],
17
+ 'lung': [-1000, 1000],
18
+ 'pelvis': [-55, 200],
19
+ 'liver': [-25, 230],
20
+ 'colon': [-68, 187],
21
+ 'pancreas': [-100, 200]}
22
+
23
+ def process_intensity_image(image_data, is_CT, site=None):
24
+ # process intensity-based image. If CT, apply site specific windowing
25
+
26
+ # image_data: 2D numpy array of shape (H, W)
27
+
28
+ # return: 3-channel numpy array of shape (H, W, 3) as model input
29
+
30
+ if is_CT:
31
+ # process image with windowing
32
+ if site and site in CT_WINDOWS:
33
+ window = CT_WINDOWS[site]
34
+ else:
35
+ raise ValueError(f'Please choose CT site from {CT_WINDOWS.keys()}')
36
+ lower_bound, upper_bound = window
37
+ else:
38
+ # process image with intensity range 0.5-99.5 percentile
39
+ lower_bound, upper_bound = np.percentile(
40
+ image_data[image_data > 0], 0.5
41
+ ), np.percentile(image_data[image_data > 0], 99.5)
42
+
43
+ image_data_pre = np.clip(image_data, lower_bound, upper_bound)
44
+ image_data_pre = (
45
+ (image_data_pre - image_data_pre.min())
46
+ / (image_data_pre.max() - image_data_pre.min())
47
+ * 255.0
48
+ )
49
+
50
+ # pad to square with equal padding on both sides
51
+ shape = image_data_pre.shape
52
+ if shape[0] > shape[1]:
53
+ pad = (shape[0]-shape[1])//2
54
+ pad_width = ((0,0), (pad, pad))
55
+ elif shape[0] < shape[1]:
56
+ pad = (shape[1]-shape[0])//2
57
+ pad_width = ((pad, pad), (0,0))
58
+ else:
59
+ pad_width = None
60
+
61
+ if pad_width is not None:
62
+ image_data_pre = np.pad(image_data_pre, pad_width, 'constant', constant_values=0)
63
+
64
+ # resize image to 1024x1024
65
+ image_size = 1024
66
+ resize_image = transform.resize(image_data_pre, (image_size, image_size), order=3,
67
+ mode='constant', preserve_range=True, anti_aliasing=True)
68
+
69
+ # convert to 3-channel image
70
+ resize_image = np.stack([resize_image]*3, axis=-1)
71
+
72
+ return resize_image.astype(np.uint8)
73
+
74
+
75
+
76
+ def read_dicom(image_path, is_CT, site=None):
77
+ # read dicom file and return pixel data
78
+
79
+ # dicom_file: str, path to dicom file
80
+ # is_CT: bool, whether image is CT or not
81
+ # site: str, one of CT_WINDOWS.keys()
82
+ # return: 2D numpy array of shape (H, W)
83
+
84
+ ds = pydicom.dcmread(image_path)
85
+ image_array = ds.pixel_array * ds.RescaleSlope + ds.RescaleIntercept
86
+
87
+ image_array = process_intensity_image(image_array, is_CT, site)
88
+
89
+ return image_array
90
+
91
+
92
+ def read_nifti(image_path, is_CT, slice_idx, site=None, HW_index=(0, 1), channel_idx=None):
93
+ # read nifti file and return pixel data
94
+
95
+ # image_path: str, path to nifti file
96
+ # is_CT: bool, whether image is CT or not
97
+ # slice_idx: int, slice index to read
98
+ # site: str, one of CT_WINDOWS.keys()
99
+ # HW_index: tuple, index of height and width in the image shape
100
+ # return: 2D numpy array of shape (H, W)
101
+
102
+
103
+ nii = nib.load(image_path)
104
+ image_array = nii.get_fdata()
105
+
106
+ if HW_index != (0, 1):
107
+ image_array = np.moveaxis(image_array, HW_index, (0, 1))
108
+
109
+ # get slice
110
+ if channel_idx is None:
111
+ image_array = image_array[:, :, slice_idx]
112
+ else:
113
+ image_array = image_array[:, :, slice_idx, channel_idx]
114
+
115
+ image_array = process_intensity_image(image_array, is_CT, site)
116
+ return image_array
117
+
118
+
119
+
120
+ def read_rgb(image_path):
121
+ # read RGB image and return resized pixel data
122
+
123
+ # image_path: str, path to RGB image
124
+ # return: BytesIO buffer
125
+
126
+ # read image into numpy array
127
+ image = Image.open(image_path)
128
+ image = np.array(image)
129
+ if len(image.shape) == 2:
130
+ image = np.stack([image]*3, axis=-1)
131
+ elif image.shape[2] == 4:
132
+ image = image[:,:,:3]
133
+
134
+ # pad to square with equal padding on both sides
135
+ shape = image.shape
136
+ if shape[0] > shape[1]:
137
+ pad = (shape[0]-shape[1])//2
138
+ pad_width = ((0,0), (pad, pad), (0,0))
139
+ elif shape[0] < shape[1]:
140
+ pad = (shape[1]-shape[0])//2
141
+ pad_width = ((pad, pad), (0,0), (0,0))
142
+ else:
143
+ pad_width = None
144
+
145
+ if pad_width is not None:
146
+ image = np.pad(image, pad_width, 'constant', constant_values=0)
147
+
148
+ # resize image to 1024x1024 for each channel
149
+ image_size = 1024
150
+ resize_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
151
+ for i in range(3):
152
+ resize_image[:,:,i] = transform.resize(image[:,:,i], (image_size, image_size), order=3,
153
+ mode='constant', preserve_range=True, anti_aliasing=True)
154
+
155
+ return resize_image
156
+
157
+
158
+
159
+ def get_instances(mask):
160
+ # get intances from binary mask
161
+ seg = sitk.GetImageFromArray(mask)
162
+ filled = sitk.BinaryFillhole(seg)
163
+ d = sitk.SignedMaurerDistanceMap(filled, insideIsPositive=False, squaredDistance=False, useImageSpacing=False)
164
+
165
+ ws = sitk.MorphologicalWatershed( d, markWatershedLine=False, level=1)
166
+ ws = sitk.Mask( ws, sitk.Cast(seg, ws.GetPixelID()))
167
+ ins_mask = sitk.GetArrayFromImage(ws)
168
+
169
+ # filter out instances with small area outliers
170
+ props = measure.regionprops_table(ins_mask, properties=('label', 'area'))
171
+ mean_area = np.mean(props['area'])
172
+ std_area = np.std(props['area'])
173
+
174
+ threshold = mean_area - 2*std_area - 1
175
+ ins_mask_filtered = ins_mask.copy()
176
+ for i, area in zip(props['label'], props['area']):
177
+ if area < threshold:
178
+ ins_mask_filtered[ins_mask == i] = 0
179
+
180
+ return ins_mask_filtered
181
+
182
+
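A minimal usage sketch for the readers defined above (the file paths are hypothetical, and 'abdomen' is assumed to be one of the CT_WINDOWS keys defined earlier in this file):

import numpy as np
from inference_utils.processing_utils import read_dicom, read_nifti, read_rgb, get_instances

# CT DICOM slice: rescaled, windowed for the chosen site, resized to 1024x1024x3
ct_slice = read_dicom("example_ct_slice.dcm", is_CT=True, site="abdomen")

# slice 40 of a non-CT NIfTI volume
mr_slice = read_nifti("example_mr_volume.nii.gz", is_CT=False, slice_idx=40)

# RGB image (e.g. pathology): padded to a square, then resized to 1024x1024x3
rgb_image = read_rgb("examples/Part_1_516_pathology_breast.png")

# split a binary segmentation mask into instances via watershed
pred_mask = (rgb_image[:, :, 0] > 128).astype(np.uint8)  # placeholder binary mask
instance_mask = get_instances(pred_mask)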
inference_utils/target_dist.json ADDED
@@ -0,0 +1 @@
1
+ {"CT-Abdomen": {"postcava": [[244.8001455798728, 5.314270814858824], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391], [7.183679633251858, 5.168810995426391]], "aorta": [[570.5260544851909, 8.97527503179567], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238], [3.3715049586348242, 1.4971164544774238]], "right kidney": [[831.8568013426873, 14.991866448573818], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316], [4.970270375121704, 3.050385928796316]], "kidney": [[824.7288483151449, 17.740666994112335], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919], [5.134294543833492, 3.188304874790919]], "left kidney": [[765.9269280548916, 14.314482540419498], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515], [5.084499568327313, 3.2061871556243515]], "duodenum": [[121.5002253116006, 5.0616837393558045], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173], [13.60882943690214, 15.313999640884173]], "pancreas": [[182.85416969377923, 6.9039775525067135], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656], [17.489564177159146, 14.924761571311656]], "liver (non abdomen window)": [[481.5690096331249, 8.413924027868077], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198], [6.047563882283547, 6.86712354789198]], "liver": [[497.88613290346797, 8.79208581405346], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742], [20.552757782824486, 16.312687320589742]], "spleen": [[496.77984794364835, 8.498216025126785], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987], [14.594250163059534, 10.71357260923987]], "stomach": [[137.7555592980079, 3.928159238756134], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921], [5.978844398494112, 10.238758157160921]], "gallbladder": [[109.56988864543307, 3.4765854683723596], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384], [32.35084093358493, 41.113482214152384]], "left adrenal gland": [[121.60075395406241, 4.266683492995461], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753], [17.017417548383662, 18.48528509828753]], "adrenal gland": [[182.4265613513338, 7.813186080282246], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345], [18.97442893128976, 20.599617257380345]], "right adrenal gland": [[158.21570288963346, 5.736947411814261], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653], [17.17089273745977, 19.09450167978653]], "bladder": [[172.667607742299, 4.6885066612866835], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909], [42.56984081338662, 56.45115036285909]], "esophagus": [[253.86092392814248, 6.886078359154348], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301], [13.252110919965341, 15.437200766467301]]}, "CT-Chest": {"nodule": [[115.14726334918862, 3.0043952160348844], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393], [5.275338876748403, 7.899248653413393]], "COVID-19 infection": [[226.93782607812352, 10.662200522447263], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407], [11.74323002038987, 23.773784082857407]], "tumor": [[81.39154648592063, 3.0363381821985254], [9.799683628807484, 19.248706134279548], 
[9.799683628807484, 19.248706134279548], [9.799683628807484, 19.248706134279548]]}, "MRI-Abdomen": {"aorta": [[840.9822169946456, 13.699556855062456], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954], [2.9798604461548766, 1.19765659474954]], "postcava": [[151.3891903352374, 4.700455115571472], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995], [3.065810750535689, 2.074722812609995]], "right kidney": [[613.4017011464975, 11.282616103318485], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867], [4.63815461741129, 2.2967740371944867]], "duodenum": [[88.51851857758399, 5.251374959142798], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745], [9.350910364523573, 8.85976960554745]], "kidney": [[831.5762248415444, 18.739059302777875], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527], [5.715871882386201, 2.6205541393599527]], "left kidney": [[255.4744196400276, 5.573793361388763], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708], [6.081920320421431, 2.930383603114708]], "liver": [[491.1931789168259, 9.294627086787225], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463], [10.138029098677139, 6.28829088692463]], "pancreas": [[136.2304629992425, 5.676744286342953], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567], [19.631392824605342, 11.528214201070567]], "gallbladder": [[75.18767252055355, 2.8711737605829892], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496], [14.500831537679415, 20.696868858705496]], "stomach": [[89.16380420023327, 4.461224829090838], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376], [10.266772743753412, 16.943404348738376]], "spleen": [[413.92566589639046, 7.99961594912814], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216], [7.267087388529462, 5.149714876028216]], "left adrenal gland": [[86.44109991236728, 4.826813402237061], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408], [17.153928230900817, 14.858036650050408]], "adrenal gland": [[303.9642820935704, 16.729857009916806], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544], [19.500678047021523, 17.02588768312544]], "right adrenal gland": [[172.36803145644578, 8.050377438528958], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772], [15.257519917725558, 13.431078702905772]], "esophagus": [[193.1348898340059, 7.6397334220243325], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354], [12.240331385391299, 16.812971132953354]]}, "MRI-Cardiac": {"left heart ventricle": [[964.9072936969454, 17.21177762137991], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713], [5.880290818671821, 4.100959742819713]], "myocardium": [[448.3393673888417, 17.591805257426998], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415], [5.208511169313307, 15.910705163394415]], "right heart ventricle": [[359.88937669636215, 9.392153523781843], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979], [5.924076424141962, 5.554667293878979]]}, "MRI-FLAIR-Brain": {"edema": [[69.4159007224176, 5.568921766085619], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592], [13.400334168570177, 4.965265405638592]], "tumor 
core": [[154.26935124167449, 8.089254912853598], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397], [14.908340542645478, 4.820086393609397]], "whole tumor": [[485.48717118600956, 16.01178236475156], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145], [25.74323915508559, 8.636438181178145]]}, "MRI-T1-Gd-Brain": {"enhancing tumor": [[175.6437881777937, 7.539344668413025], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689], [17.864705093992068, 5.36432831714689]], "non-enhancing tumor": [[37.6625733247702, 3.8454536110058246], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484], [6.568014639412233, 8.446289690167484]], "tumor core": [[180.88223552813486, 6.610443841067055], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197], [9.70294999498087, 5.30262880784197]]}, "Pathology": {"connective tissue cells": [[46.71165884847293, 4.997126203483956], [9.942495884846476, 15.700775443760845], [4.328453739888501, 18.42621798468577], [9.798096322131162, 11.920352021312304]], "inflammatory cells": [[39.600337990197595, 3.1848025413959706], [6.287418328538852, 20.538379638162322], [2.9521703595392146, 25.264465092284006], [6.559595490616054, 12.004686961917436]], "neoplastic cells": [[82.29374052289526, 8.22429924322936], [9.592296798563375, 14.818916788142138], [4.948629785308088, 19.78516221506478], [10.729094314024243, 12.934345198477494]], "epithelial cells": [[91.75183574899573, 9.577544361042948], [13.469843493323452, 27.305962287612964], [4.696928248406198, 25.254143364646463], [11.077634907582583, 13.487595094752443]]}, "X-Ray-Chest": {"left lung": [[529.1669758355144, 7.465035502868491], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364], [8.220284641505614, 11.62958600654364]], "lung": [[465.7809501354513, 7.147122106450173], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102], [8.781306299078446, 12.335455073688102]], "right lung": [[567.6127039725319, 7.532428563004494], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746], [8.067311420424144, 11.229763331648746]]}, "Ultrasound-Cardiac": {"left heart atrium": [[1188.687550702627, 24.234766943758856], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291], [5.18832820435626, 13.705576921752291]], "left heart ventricle": [[2787.334986695437, 58.297232816307506], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377], [15.28158405889985, 56.95469460140377]]}, "Endoscopy": {"neoplastic polyp": [[392.89875472390315, 5.4678888279040745], [7.477729277754545, 1.6522601344780465], [7.2704247484339035, 6.347521355120636], [4.3902399436060335, 6.543658310376327]], "polyp": [[163.7838288028474, 3.4851615302599117], [7.03659746479883, 1.9088902542177986], [6.992807172875011, 6.756628353721484], [5.185761648208865, 8.977427344868255]], "non-neoplastic polyp": [[214.9199548332033, 4.360826895414348], [7.303363948417486, 1.9789835935004905], [10.54652900087687, 9.009706115553772], [6.917879576439251, 10.404634951284532]]}, "Fundus": {"optic cup": [[1482.9561484784422, 35.78105120937013], [52.1031548324398, 1.5080077510381715], [10.023538467761934, 3.1641925551155046], [3.394564722036805, 2.4391933423559626]], "optic disc": [[626.9141229495486, 20.95002931507066], [18.278454005466408, 1.8261365514325893], [16.42282430959315, 11.171338052048034], [4.8937792939550135, 6.987302868644637]]}, 
"Dermoscopy": {"lesion": [[134.43456931870887, 4.743684855379663], [5.18053578956456, 2.3527492367343634], [3.809383004477107, 6.368793378843402], [2.3888068456218847, 6.655396307215968]], "melanoma": [[454.17848530764076, 9.6466178116726], [4.022144360826467, 7.870140640677671], [4.87109613458874, 18.93721534855073], [3.107895746664011, 13.604075970992069]]}, "OCT": {"edema": [[260.11475018501574, 7.379315940573871], [4.162158474003, 17.437425953761988], [12.65808078622105, 81.37165793634547], [1.763378481483125, 4.427309203795247]]}}
main.py ADDED
@@ -0,0 +1,106 @@
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from huggingface_hub import hf_hub_download
8
+ from modeling.BaseModel import BaseModel
9
+ from modeling import build_model
10
+ from utilities.distributed import init_distributed
11
+ from utilities.arguments import load_opt_from_config_files
12
+ from utilities.constants import BIOMED_CLASSES
13
+ from inference_utils.inference import interactive_infer_image
14
+
15
+
16
+ def overlay_masks(image, masks, colors):
17
+ overlay = image.copy()
18
+ overlay = np.array(overlay, dtype=np.uint8)
19
+ for mask, color in zip(masks, colors):
20
+ overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
21
+ np.uint8
22
+ )
23
+ return Image.fromarray(overlay)
24
+
25
+
26
+ def generate_colors(n):
27
+ cmap = plt.get_cmap("tab10")
28
+ colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
29
+ return colors
30
+
31
+
32
+ def init_model():
33
+ # Download model
34
+ model_file = hf_hub_download(
35
+ repo_id="microsoft/BiomedParse",
36
+ filename="biomedparse_v1.pt",
37
+ token=os.getenv("HF_TOKEN"),
38
+ )
39
+
40
+ # Initialize model
41
+ conf_files = "configs/biomedparse_inference.yaml"
42
+ opt = load_opt_from_config_files([conf_files])
43
+ opt = init_distributed(opt)
44
+
45
+ model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()
46
+ with torch.no_grad():
47
+ model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
48
+ BIOMED_CLASSES + ["background"], is_eval=True
49
+ )
50
+
51
+ return model
52
+
53
+
54
+ def predict(image, prompts):
55
+ if not prompts:
56
+ return None
57
+
58
+ # Convert string input to list
59
+ prompts = [p.strip() for p in prompts.split(",")]
60
+
61
+ # Convert to RGB if needed
62
+ if image.mode != "RGB":
63
+ image = image.convert("RGB")
64
+
65
+ # Get predictions
66
+ pred_mask = interactive_infer_image(model, image, prompts)
67
+
68
+ # Generate visualization
69
+ colors = generate_colors(len(prompts))
70
+ pred_overlay = overlay_masks(
71
+ image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
72
+ )
73
+
74
+ return pred_overlay
75
+
76
+
77
+ def run():
78
+ global model
79
+ model = init_model()
80
+
81
+ demo = gr.Interface(
82
+ fn=predict,
83
+ inputs=[
84
+ gr.Image(type="pil", label="Input Image"),
85
+ gr.Textbox(
86
+ label="Prompts",
87
+ placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
88
+ ),
89
+ ],
90
+ outputs=gr.Image(type="pil", label="Prediction"),
91
+ title="BiomedParse Demo",
92
+ description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
93
+ examples=[
94
+ [
95
+ "examples/Part_1_516_pathology_breast.png",
96
+ "neoplastic cells, inflammatory cells",
97
+ ]
98
+ ],
99
+ )
100
+
101
+ demo.launch(server_name="0.0.0.0", server_port=7860)
102
+
103
+
104
+ if __name__ == "__main__":
105
+ print(f"HF_TOKEN is {'set' if os.getenv('HF_TOKEN') else 'not set'}")
106
+ run()
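A minimal sketch of running the same pipeline without the Gradio UI (assumes a CUDA device and an HF_TOKEN environment variable with access to microsoft/BiomedParse):

from PIL import Image
from main import init_model, generate_colors, overlay_masks
from inference_utils.inference import interactive_infer_image

model = init_model()
image = Image.open("examples/Part_1_516_pathology_breast.png").convert("RGB")
prompts = ["neoplastic cells", "inflammatory cells"]

# one probability map per prompt
pred_mask = interactive_infer_image(model, image, prompts)

# threshold at 0.5 and draw the colored overlay, as predict() does above
colors = generate_colors(len(prompts))
masks = [(pred_mask[i] > 0.5).astype("uint8") for i in range(len(prompts))]
overlay_masks(image, masks, colors).save("prediction_overlay.png")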
modeling/BaseModel.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import logging
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from utilities.model import align_and_update_state_dicts
8
+
9
+ from utilities.distributed import init_distributed
10
+ from utilities.arguments import load_opt_from_config_files
11
+
12
+ import huggingface_hub
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class BaseModel(nn.Module):
18
+ def __init__(self, opt, module: nn.Module):
19
+ super(BaseModel, self).__init__()
20
+ self.opt = opt
21
+ self.model = module
22
+
23
+ def forward(self, *inputs, **kwargs):
24
+ outputs = self.model(*inputs, **kwargs)
25
+ return outputs
26
+
27
+ def save_pretrained(self, save_dir):
28
+ torch.save(self.model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))
29
+
30
+ def from_pretrained(self, pretrained, filename: str = "biomedparse_v1.pt",
31
+ local_dir: str = "./pretrained", config_dir: str = "./configs"):
32
+ if pretrained.startswith("hf_hub:"):
33
+ hub_name = pretrained.split(":")[1]
34
+ huggingface_hub.hf_hub_download(hub_name, filename=filename,
35
+ local_dir=local_dir)
36
+ huggingface_hub.hf_hub_download(hub_name, filename="config.yaml",
37
+ local_dir=config_dir)
38
+ load_dir = os.path.join(local_dir, filename)
39
+ else:
40
+ load_dir = pretrained
41
+
42
+ state_dict = torch.load(load_dir, map_location=self.opt['device'])
43
+ state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
44
+ self.model.load_state_dict(state_dict, strict=False)
45
+ return self
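A sketch of the two loading paths supported by from_pretrained (assuming opt is built the same way as in main.py, so that opt['device'] is populated):

from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.arguments import load_opt_from_config_files
from utilities.distributed import init_distributed

opt = init_distributed(load_opt_from_config_files(["configs/biomedparse_inference.yaml"]))
model = BaseModel(opt, build_model(opt))

# 1) load weights from a local checkpoint file
model = model.from_pretrained("./pretrained/biomedparse_v1.pt")

# 2) or let from_pretrained download biomedparse_v1.pt and config.yaml from the Hub first
model = model.from_pretrained("hf_hub:microsoft/BiomedParse")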
modeling/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .architectures import build_model
modeling/architectures/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .xdecoder_model import *
2
+ from .seem_model_v0 import *
3
+ from .seem_model_v1 import *
4
+ from .seem_model_demo import *
5
+ from .build import build_model
modeling/architectures/build.py ADDED
@@ -0,0 +1,22 @@
1
+ _model_entrypoints = {}
2
+
3
+
4
+ def build_model(config, **kwargs):
5
+ model_name = config['MODEL']['NAME']
6
+
7
+ if not is_model(model_name):
8
+ raise ValueError(f'Unknown model: {model_name}')
9
+
10
+ return model_entrypoints(model_name)(config, **kwargs)
11
+
12
+ def register_model(fn):
13
+ module_name_split = fn.__module__.split('.')
14
+ model_name = module_name_split[-1]
15
+ _model_entrypoints[model_name] = fn
16
+ return fn
17
+
18
+ def model_entrypoints(model_name):
19
+ return _model_entrypoints[model_name]
20
+
21
+ def is_model(model_name):
22
+ return model_name in _model_entrypoints
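How the registry is keyed: register_model stores the decorated function under the name of the module that defines it (fn.__module__.split('.')[-1]), so the config's MODEL.NAME must match the module file name, not the function name. A hypothetical sketch (toy_model.py does not exist in this commit):

# hypothetical file: modeling/architectures/toy_model.py
import torch.nn as nn
from .build import register_model

class ToyModel(nn.Module):
    def forward(self, x):
        return x

@register_model
def get_toy_model(cfg, **kwargs):
    # registered under the key 'toy_model' (the module name)
    return ToyModel()

# elsewhere: build_model({'MODEL': {'NAME': 'toy_model'}}) returns a ToyModel instance,
# provided modeling/architectures/__init__.py imports toy_model so the decorator runs.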
modeling/architectures/seem_model_demo.py ADDED
@@ -0,0 +1,923 @@
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All at Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import random
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from kornia.contrib import distance_transform
15
+
16
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
17
+ from detectron2.utils.memory import retry_if_cuda_oom
18
+ from detectron2.data import MetadataCatalog
19
+
20
+ from .build import register_model
21
+
22
+ from ..utils import configurable, get_class_names, get_iou
23
+ from ..vision.backbone import build_backbone, Backbone
24
+ from ..body import build_xdecoder_head
25
+ from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
26
+ from ..language import build_language_encoder
27
+ from ..language.loss import vl_similarity
28
+ from utilities.prompt_engineering import prompt_engineering
29
+ from utilities.constants import COCO_PANOPTIC_CLASSES
30
+
31
+
32
+ class GeneralizedSEEM(nn.Module):
33
+
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ *,
38
+ backbone: Backbone,
39
+ sem_seg_head: nn.Module,
40
+ criterion: nn.Module,
41
+ losses: dict,
42
+ num_queries: int,
43
+ object_mask_threshold: float,
44
+ overlap_threshold: float,
45
+ metadata,
46
+ task_switch: dict,
47
+ phrase_prob: float,
48
+ size_divisibility: int,
49
+ sem_seg_postprocess_before_inference: bool,
50
+ pixel_mean: Tuple[float],
51
+ pixel_std: Tuple[float],
52
+ # inference
53
+ semantic_on: bool,
54
+ panoptic_on: bool,
55
+ instance_on: bool,
56
+ test_topk_per_image: int,
57
+ train_dataset_name: str,
58
+ interactive_mode: str,
59
+ interactive_iter: str,
60
+ dilation_kernel: torch.Tensor,
61
+ ):
62
+ super().__init__()
63
+ self.backbone = backbone
64
+ self.sem_seg_head = sem_seg_head
65
+ self.criterion = criterion
66
+ self.losses = losses
67
+ self.num_queries = num_queries
68
+ self.overlap_threshold = overlap_threshold
69
+ self.object_mask_threshold = object_mask_threshold
70
+ self.metadata = metadata
71
+ if size_divisibility < 0:
72
+ # use backbone size_divisibility if not set
73
+ size_divisibility = self.backbone.size_divisibility
74
+ self.size_divisibility = size_divisibility
75
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
76
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
77
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
78
+
79
+ # additional args
80
+ self.semantic_on = semantic_on
81
+ self.instance_on = instance_on
82
+ self.panoptic_on = panoptic_on
83
+
84
+ # caption argument
85
+ self.task_switch = task_switch
86
+ self.phrase_prob = phrase_prob
87
+
88
+ self.test_topk_per_image = test_topk_per_image
89
+ self.train_class_names = None
90
+ self.interactive_mode = interactive_mode
91
+ self.interactive_iter = interactive_iter
92
+
93
+ if not self.semantic_on:
94
+ assert self.sem_seg_postprocess_before_inference
95
+
96
+ self.register_buffer("dilation_kernel", dilation_kernel)
97
+
98
+ @classmethod
99
+ def from_config(cls, cfg):
100
+ enc_cfg = cfg['MODEL']['ENCODER']
101
+ dec_cfg = cfg['MODEL']['DECODER']
102
+
103
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
104
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
105
+
106
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
107
+ 'mask': dec_cfg.get('MASK', True),
108
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
109
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
110
+ 'openimage': openimage_switch,
111
+ 'visual': dec_cfg['VISUAL'].get('ENABLED', False),
112
+ 'audio': dec_cfg['AUDIO'].get('ENABLED', False)}
113
+
114
+ # build model
115
+ extra = {'task_switch': task_switch}
116
+ backbone = build_backbone(cfg)
117
+ lang_encoder = build_language_encoder(cfg)
118
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
119
+
120
+ # Training Settings.
121
+ loss_weights = {}
122
+ matcher = None
123
+ losses = {}
124
+ weight_dict = {}
125
+ grd_weight = {}
126
+ top_x_layers = {}
127
+ criterion = None
128
+ train_dataset_name = None
129
+ phrase_prob = None
130
+ # Loss parameters:
131
+ deep_supervision = None
132
+ no_object_weight = None
133
+
134
+ interactive_mode = 'best'
135
+ interactive_iter = 20
136
+ dilation = 3
137
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
138
+
139
+ return {
140
+ "backbone": backbone,
141
+ "sem_seg_head": sem_seg_head,
142
+ "criterion": criterion,
143
+ "losses": losses,
144
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
145
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
146
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
147
+ "metadata": None,
148
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
149
+ "sem_seg_postprocess_before_inference": (
150
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
151
+ or dec_cfg['TEST']['PANOPTIC_ON']
152
+ or dec_cfg['TEST']['INSTANCE_ON']
153
+ ),
154
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
155
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
156
+ "task_switch": task_switch,
157
+ "phrase_prob": phrase_prob,
158
+ # inference
159
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
160
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
161
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
162
+ "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
163
+ "train_dataset_name": train_dataset_name,
164
+ "interactive_mode": interactive_mode,
165
+ "interactive_iter": interactive_iter,
166
+ "dilation_kernel": dilation_kernel,
167
+ }
168
+
169
+ @property
170
+ def device(self):
171
+ return self.pixel_mean.device
172
+
173
+ def forward(self, batched_inputs, mode='default'):
174
+ if self.training:
175
+ losses = {}
176
+ if self.task_switch['mask']:
177
+ losses_seg = self.forward_seg(batched_inputs)
178
+ losses.update(losses_seg)
179
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
180
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
181
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
182
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
183
+ losses.update(losses_openimage)
184
+ for k in list(losses.keys()):
185
+ if k in self.criterion.weight_dict:
186
+ losses[k] *= self.criterion.weight_dict[k]
187
+ else: # remove this loss if not specified in `weight_dict`
188
+ losses.pop(k)
189
+ return losses
190
+ else:
191
+ if mode == 'interactive':
192
+ return self.evaluate_interactive(batched_inputs)
193
+ elif mode == 'grounding_spatial':
194
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
195
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
196
+ return self.evaluate_grounding(batched_inputs, mode)
197
+ else:
198
+ return self.evaluate(batched_inputs)
199
+
200
+
201
+ def forward_seg(self, batched_inputs):
202
+ images = [x["image"].to(self.device) for x in batched_inputs]
203
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
204
+ images = ImageList.from_tensors(images, self.size_divisibility)
205
+
206
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
207
+
208
+ extra = {}
209
+ # mask classification target
210
+ if "instances" in batched_inputs[0]:
211
+ # input bounding box is checked to be correct.
212
+ targets = self.prepare_targets(batched_inputs, images)
213
+
214
+ if self.task_switch['grounding']:
215
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
216
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
217
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
218
+ grounding_tokens[non_zero_query_mask] = 0
219
+
220
+ extra['grounding_tokens'] = grounding_tokens
221
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
222
+
223
+ if self.task_switch['spatial']:
224
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
225
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
226
+ fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
227
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
228
+
229
+ features = self.backbone(images.tensor)
230
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
231
+
232
+ # forward spatial only without gradient
233
+ if self.task_switch['spatial']:
234
+ with torch.no_grad():
235
+ # generate random integeter between [0,3]
236
+ rand_iter_num = random.randint(0, 2)
237
+ for i in range(rand_iter_num):
238
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
239
+ extra.update(outputs)
240
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
241
+
242
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
243
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
244
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
245
+ 'false_positive_mask': extra['false_positive_mask']}
246
+ # bipartite matching-based loss
247
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
248
+ losses = self.criterion(outputs, targets, extra)
249
+
250
+ del outputs
251
+ return losses
252
+
253
+ def evaluate_demo(self, batched_inputs):
254
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
255
+ images = [x["image"].to(self.device) for x in batched_inputs]
256
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
257
+ images = ImageList.from_tensors(images, self.size_divisibility)
258
+ img_bs = images.tensor.shape[0]
259
+
260
+ targets = targets_grounding = queries_grounding = None
261
+ features = self.backbone(images.tensor)
262
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
263
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
264
+
265
+ extra = {}
266
+ if 'stroke' in batched_inputs[0]:
267
+ pos_masks = (batched_inputs[0]['stroke'].to(self.device)).unbind(0)
268
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
269
+ neg_masks = (batched_inputs[0]['stroke'].to(self.device) & False).unbind(0)
270
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
271
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
272
+
273
+ if 'visual' in batched_inputs[0]:
274
+ extra.update(batched_inputs[0]['visual'])
275
+
276
+ if 'text' in batched_inputs[0]:
277
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['text'], name='grounding', token=False, norm=False)
278
+ token_emb = gtext['token_emb']
279
+ tokens = gtext['tokens']
280
+ query_emb = token_emb[tokens['attention_mask'].bool()]
281
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
282
+ extra['grounding_tokens'] = query_emb[:,None]
283
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
284
+ extra['grounding_class'] = gtext['class_emb']
285
+
286
+ if 'audio' in batched_inputs[0]:
287
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(batched_inputs[0]['audio'], name='grounding', token=False, norm=False)
288
+ token_emb = gtext['token_emb']
289
+ tokens = gtext['tokens']
290
+ query_emb = token_emb[tokens['attention_mask'].bool()]
291
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
292
+ extra['audio_tokens'] = query_emb[:,None]
293
+ extra['audio_nonzero_mask'] = non_zero_query_mask.t()
294
+ extra['audio_class'] = gtext['class_emb']
295
+
296
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='demo')
297
+ return outputs, images.tensor.shape, extra
298
+ # note: everything below this return is unreachable (it duplicates evaluate_interactive)
299
+ assert self.task_switch['spatial']
300
+ assert 'spatial_query' in batched_inputs[0]
301
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
302
+
303
+ images = [x["image"].to(self.device) for x in batched_inputs]
304
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
305
+ images = ImageList.from_tensors(images, self.size_divisibility)
306
+ img_bs = images.tensor.shape[0]
307
+
308
+ targets = targets_grounding = queries_grounding = None
309
+ extra = {}
310
+
311
+ features = self.backbone(images.tensor)
312
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
313
+
314
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
315
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
316
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
317
+ mask_features = mask_features.repeat(nm,1,1,1)
318
+
319
+ all_batch_shape_iou = []
320
+ pred_smask_pointer = None
321
+ prev_smask_pointer = None
322
+ pred_smask_all = None
323
+
324
+ query_index = self.sem_seg_head.predictor.query_index
325
+ assert self.interactive_mode == 'best'
326
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
327
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
328
+
329
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
330
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
331
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
332
+
333
+ for i in range(self.interactive_iter):
334
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
335
+ extra.update(outputs)
336
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
337
+
338
+ s = image_sizes[0]
339
+ b = batched_inputs[0]
340
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
341
+ gt_smask = b['gt_masks_orisize']
342
+ all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
343
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
344
+
345
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
346
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
347
+ return processed_results
348
+
349
+ def evaluate(self, batched_inputs):
350
+ images = [x["image"].to(self.device) for x in batched_inputs]
351
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
352
+
353
+ images = ImageList.from_tensors(images, self.size_divisibility)
354
+ img_bs = images.tensor.shape[0]
355
+
356
+ targets = targets_grounding = queries_grounding = None
357
+ features = self.backbone(images.tensor)
358
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
359
+
360
+ mask_cls_results = outputs["pred_logits"]
361
+ mask_pred_results = outputs["pred_masks"]
362
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
363
+
364
+ # upsample masks
365
+ mask_pred_results = F.interpolate(
366
+ mask_pred_results,
367
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
368
+ mode="bilinear",
369
+ align_corners=False,
370
+ )
371
+
372
+ input_size = mask_pred_results.shape[-2:]
373
+ del outputs
374
+
375
+ processed_results = []
376
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
377
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
378
+ ):
379
+ height = input_per_image.get("height", image_size[0])
380
+ width = input_per_image.get("width", image_size[1])
381
+ processed_results.append({})
382
+
383
+ if self.sem_seg_postprocess_before_inference:
384
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
385
+ mask_pred_result, image_size, height, width
386
+ )
387
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
388
+
389
+ # semantic segmentation inference
390
+ if self.semantic_on:
391
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
392
+ if not self.sem_seg_postprocess_before_inference:
393
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
394
+ processed_results[-1]["sem_seg"] = r
395
+
396
+ # panoptic segmentation inference
397
+ if self.panoptic_on:
398
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
399
+ processed_results[-1]["panoptic_seg"] = panoptic_r
400
+
401
+ # instance segmentation inference
402
+ if self.instance_on:
403
+ if self.task_switch['bbox']:
404
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
405
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
406
+ processed_results[-1]["instances"] = instance_r
407
+
408
+ return processed_results
409
+
410
+ def evaluate_interactive(self, batched_inputs):
411
+ assert self.task_switch['spatial']
412
+ assert 'spatial_query' in batched_inputs[0]
413
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
414
+
415
+ images = [x["image"].to(self.device) for x in batched_inputs]
416
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
417
+ images = ImageList.from_tensors(images, self.size_divisibility)
418
+ img_bs = images.tensor.shape[0]
419
+
420
+ targets = targets_grounding = queries_grounding = None
421
+ extra = {}
422
+
423
+ features = self.backbone(images.tensor)
424
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
425
+
426
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
427
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
428
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
429
+ mask_features = mask_features.repeat(nm,1,1,1)
430
+
431
+ all_batch_shape_iou = []
432
+ pred_smask_pointer = None
433
+ prev_smask_pointer = None
434
+ pred_smask_all = None
435
+
436
+ query_index = self.sem_seg_head.predictor.query_index
437
+ assert self.interactive_mode == 'best'
438
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
439
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
440
+
441
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
442
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
443
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
444
+
445
+ for i in range(self.interactive_iter):
446
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
447
+ extra.update(outputs)
448
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
449
+
450
+ s = image_sizes[0]
451
+ b = batched_inputs[0]
452
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
453
+ gt_smask = b['gt_masks_orisize']
454
+ all_batch_shape_iou += [get_iou(gt_smask, pred_smask_all)]
455
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
456
+
457
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
458
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
459
+ return processed_results
460
+
461
+ def evaluate_referring_image(self, batched_inputs, extra={}):
462
+ assert self.task_switch['spatial']
463
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
464
+ assert self.interactive_mode == 'best'
465
+
466
+ images = [x["image"].to(self.device) for x in batched_inputs]
467
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
468
+ images = ImageList.from_tensors(images, self.size_divisibility)
469
+ img_bs = images.tensor.shape[0]
470
+
471
+ targets = targets_grounding = queries_grounding = None
472
+ features = self.backbone(images.tensor)
473
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
474
+
475
+ if 'spatial_query' in batched_inputs[0]:
476
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
477
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
478
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
479
+ mask_features = mask_features.repeat(nm,1,1,1)
480
+
481
+ query_index = self.sem_seg_head.predictor.query_index
482
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
483
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
484
+
485
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
486
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
487
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
488
+
489
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
490
+ return outputs, images.tensor.shape
491
+
492
+ def evaluate_grounding(self, batched_inputs, mode):
493
+ images = [x["image"].to(self.device) for x in batched_inputs]
494
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
495
+ images = ImageList.from_tensors(images, self.size_divisibility)
496
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
497
+
498
+ extra = {}
499
+ # mask_pred_results = []
500
+ # for idx, batch_per_image in enumerate(batched_inputs):
501
+ # grd_texts = batch_per_image['groundings']['texts']
502
+ # grd_masks = []
503
+ # for anno_text in grd_texts:
504
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
505
+ # token_emb = gtext['token_emb']
506
+ # tokens = gtext['tokens']
507
+
508
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
509
+ # extra['grounding_tokens'] = grd_emb[:,None]
510
+
511
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
512
+ # features = self.backbone(images.tensor)
513
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
514
+
515
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
516
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
517
+ # t_emb = grd_emb[-1:]
518
+
519
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
520
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
521
+
522
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
523
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
524
+
525
+ # matched_id = out_prob.max(0)[1]
526
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
527
+ # mask_pred_results += [torch.cat(grd_masks)]
528
+
529
+ # comment for multi object inference.
530
+ mask_pred_results = []
531
+ for idx, batch_per_image in enumerate(batched_inputs):
532
+ grd_texts = batch_per_image['groundings']['texts']
533
+ grd_texts = [x[0] for x in grd_texts]
534
+
535
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
536
+ token_emb = gtext['token_emb']
537
+ tokens = gtext['tokens']
538
+ query_emb = token_emb[tokens['attention_mask'].bool()]
539
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
540
+
541
+ extra['grounding_tokens'] = query_emb[:,None]
542
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
543
+
544
+ features = self.backbone(images.tensor)
545
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
546
+
547
+ pred_gmasks = outputs['pred_gmasks'][idx]
548
+ v_emb = outputs['pred_gtexts'][idx]
549
+ t_emb = gtext['class_emb']
550
+
551
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
552
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
553
+
554
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
555
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
556
+
557
+ matched_id = out_prob.max(0)[1]
558
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
559
+
560
+ for i in range(len(mask_pred_results)):
561
+ # upsample masks
562
+ mask_pred_results[i] = F.interpolate(
563
+ mask_pred_results[i][None,],
564
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
565
+ mode="bilinear",
566
+ align_corners=False,
567
+ )[0]
568
+
569
+ processed_results = []
570
+ for mask_pred_result, input_per_image, image_size in zip(
571
+ mask_pred_results, batched_inputs, images.image_sizes
572
+ ):
573
+ height = input_per_image.get("height", image_size[0])
574
+ width = input_per_image.get("width", image_size[1])
575
+ processed_results.append({})
576
+
577
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
578
+ mask_pred_result, image_size, height, width
579
+ )
580
+ processed_results[-1]['grounding_mask'] = mask_pred_result
581
+
582
+ # compute bbox
583
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
584
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
585
+ # processed_results[-1]['grounding_box'] = bbox
586
+
587
+ return processed_results
588
+
589
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
590
+ images = [x["image"].to(self.device) for x in batched_inputs]
591
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
592
+ images = ImageList.from_tensors(images, self.size_divisibility)
593
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
594
+
595
+ extra = {}
596
+ dilation = 3
597
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
598
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
599
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
600
+
601
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
602
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
603
+
604
+ mask_pred_results = []
605
+ for idx, batch_per_image in enumerate(batched_inputs):
606
+ grd_texts = batch_per_image['groundings']['texts']
607
+ grd_masks = []
608
+ for idx2, anno_text in enumerate(grd_texts):
609
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
610
+
611
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
612
+ token_emb = gtext['token_emb']
613
+ tokens = gtext['tokens']
614
+
615
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
616
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
617
+ extra['grounding_tokens'] = grd_emb[:,None]
618
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
619
+
620
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
621
+ features = self.backbone(images.tensor)
622
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
623
+
624
+ pred_gmasks = outputs['pred_gmasks'][idx]
625
+ v_emb = outputs['pred_gtexts'][idx]
626
+ t_emb = gtext['class_emb']
627
+
628
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
629
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
630
+
631
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
632
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
633
+
634
+ matched_id = out_prob.max(0)[1]
635
+ grd_masks += [pred_gmasks[matched_id,:,:]]
636
+ mask_pred_results += [torch.cat(grd_masks)]
637
+
638
+ # comment for multi object inference.
639
+ # mask_pred_results = []
640
+ # for idx, batch_per_image in enumerate(batched_inputs):
641
+ # grd_texts = batch_per_image['groundings']['texts']
642
+ # grd_texts = [x[0] for x in grd_texts]
643
+
644
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
645
+ # token_emb = gtext['token_emb']
646
+ # tokens = gtext['tokens']
647
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
648
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
649
+
650
+ # extra['grounding_tokens'] = query_emb[:,None]
651
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
652
+
653
+ # features = self.backbone(images.tensor)
654
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
655
+
656
+ # pred_gmasks = outputs['pred_gmasks'][idx]
657
+ # v_emb = outputs['pred_gtexts'][idx]
658
+ # t_emb = gtext['class_emb']
659
+
660
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
661
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
662
+
663
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
664
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
665
+
666
+ # matched_id = out_prob.max(0)[1]
667
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
668
+
669
+ for i in range(len(mask_pred_results)):
670
+ # upsample masks
671
+ mask_pred_results[i] = F.interpolate(
672
+ mask_pred_results[i][None,],
673
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
674
+ mode="bilinear",
675
+ align_corners=False,
676
+ )[0]
677
+
678
+ processed_results = []
679
+ for mask_pred_result, input_per_image, image_size in zip(
680
+ mask_pred_results, batched_inputs, images.image_sizes
681
+ ):
682
+ height = input_per_image.get("height", image_size[0])
683
+ width = input_per_image.get("width", image_size[1])
684
+ processed_results.append({})
685
+
686
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
687
+ mask_pred_result, image_size, height, width
688
+ )
689
+ processed_results[-1]['grounding_mask'] = mask_pred_result
690
+
691
+ return processed_results
692
+
693
+ def prepare_targets(self, batched_inputs, images):
694
+ h_pad, w_pad = images.tensor.shape[-2:]
695
+ new_targets = []
696
+ for idx, batch_per_image in enumerate(batched_inputs):
697
+ targets_per_image = batch_per_image['instances'].to(self.device)
698
+ # pad gt
699
+ gt_masks = targets_per_image.gt_masks.tensor
700
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
701
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
702
+
703
+ gt_boxes = targets_per_image.gt_boxes.tensor
704
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
705
+ gt_boxes = gt_boxes / ratio
706
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
707
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
708
+
709
+ target_dict = {
710
+ "labels": targets_per_image.gt_classes,
711
+ "is_things": targets_per_image.is_things,
712
+ "masks": padded_masks,
713
+ "boxes": gt_boxes,
714
+ }
715
+
716
+ if self.task_switch['spatial']:
717
+ # prepare targets for spatial query
718
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
719
+
720
+ if self.task_switch['grounding']:
721
+ grd_masks = batch_per_image['groundings']['masks']
722
+ grd_texts = batch_per_image['groundings']['texts']
723
+ grd_hash = batch_per_image['groundings']['hash']
724
+ grd_task = batch_per_image['groundings']['mode']
725
+
726
+ if len(grd_masks) == 0:
727
+ padded_masks = None
728
+ else:
729
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
730
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
731
+
732
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
733
+ token_emb = gtext['token_emb']
734
+ tokens = gtext['tokens']
735
+
736
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
737
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)
738
+ selected_mask[unique_hash_id] = True
739
+
740
+ selected_token_emb = token_emb[selected_mask]
741
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
742
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
743
+
744
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
745
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
746
+ class_emb = token_emb[class_idx]
747
+
748
+ target_dict['grounding_masks'] = padded_masks
749
+ target_dict['grounding_query_embs'] = query_emb
750
+ target_dict['grounding_class_embs'] = class_emb
751
+ target_dict['grounding_hash'] = grd_hash
752
+ target_dict['grounding_task'] = grd_task
753
+
754
+ new_targets.append(target_dict)
755
+ return new_targets
756
+
757
+ def prepare_next_spaital_mask(self, outputs, batched_inputs):
758
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
759
+ if self.training:
760
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
761
+ else:
762
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
763
+
764
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
765
+ prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
766
+
767
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
768
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
769
+
770
+ # compute iou between gt and pred
771
+ iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
772
+ fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
773
+ fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
774
+
775
+ is_postive = fn_sum > fp_sum
776
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
777
+ select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])
778
+
779
+ # conv implementation
780
+ n,_,h,w=select_mask.shape
781
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
782
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
783
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
784
+ next_mask = next_mask.view(n,-1)
785
+ next_mask[max_xy_idx] = True
786
+ next_mask = next_mask.reshape((n,1,h,w)).float()
787
+ dilation = 3
788
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
789
+
790
+ # determine whether next mask is zero
791
+ keep = (iou < 0.925)
792
+ next_mask = next_mask & keep.view(-1,1,1,1)
793
+
794
+ pos_mask = []
795
+ neg_mask = []
796
+ for idx, ip in enumerate(is_postive):
797
+ if ip:
798
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
799
+ neg_mask += [outputs['spatial_query_neg_mask'][idx]]
800
+ else:
801
+ pos_mask += [outputs['spatial_query_pos_mask'][idx]]
802
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
803
+
804
+ if 'false_positive_mask' in outputs:
805
+ fp = outputs['false_positive_mask'] | fp
806
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
807
+
808
+ def semantic_inference(self, mask_cls, mask_pred):
809
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
810
+ mask_pred = mask_pred.sigmoid()
811
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
812
+ return semseg
813
+
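semantic_inference above fuses per-query class probabilities (softmax with the trailing no-object column dropped) and per-query mask probabilities into a per-class score map with a single einsum. A small shape sketch, with purely illustrative sizes:

    import torch
    import torch.nn.functional as F

    Q, K, H, W = 101, 133, 64, 64            # queries, classes, mask height/width (illustrative)
    mask_cls = torch.randn(Q, K + 1)         # class logits per query, incl. no-object column
    mask_pred = torch.randn(Q, H, W)         # mask logits per query

    cls_prob = F.softmax(mask_cls, dim=-1)[..., :-1]           # (Q, K)
    semseg = torch.einsum("qc,qhw->chw", cls_prob, mask_pred.sigmoid())
    assert semseg.shape == (K, H, W)         # per-pixel, per-class scores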
814
+ def panoptic_inference(self, mask_cls, mask_pred):
815
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
816
+ mask_pred = mask_pred.sigmoid()
817
+
818
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
819
+ cur_scores = scores[keep]
820
+ cur_classes = labels[keep]
821
+ cur_masks = mask_pred[keep]
822
+ cur_mask_cls = mask_cls[keep]
823
+ cur_mask_cls = cur_mask_cls[:, :-1]
824
+
825
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
826
+
827
+ h, w = cur_masks.shape[-2:]
828
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
829
+ segments_info = []
830
+
831
+ current_segment_id = 0
832
+
833
+ if cur_masks.shape[0] == 0:
834
+ # We didn't detect any mask :(
835
+ return panoptic_seg, segments_info
836
+ else:
837
+ # take argmax
838
+ cur_mask_ids = cur_prob_masks.argmax(0)
839
+ stuff_memory_list = {}
840
+ for k in range(cur_classes.shape[0]):
841
+ pred_class = cur_classes[k].item()
842
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
843
+ mask_area = (cur_mask_ids == k).sum().item()
844
+ original_area = (cur_masks[k] >= 0.5).sum().item()
845
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
846
+
847
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
848
+ if mask_area / original_area < self.overlap_threshold:
849
+ continue
850
+
851
+ # merge stuff regions
852
+ if not isthing:
853
+ if int(pred_class) in stuff_memory_list.keys():
854
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
855
+ continue
856
+ else:
857
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
858
+
859
+ current_segment_id += 1
860
+ panoptic_seg[mask] = current_segment_id
861
+
862
+ segments_info.append(
863
+ {
864
+ "id": current_segment_id,
865
+ "isthing": bool(isthing),
866
+ "category_id": int(pred_class),
867
+ }
868
+ )
869
+
870
+ return panoptic_seg, segments_info
871
+
872
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
873
+ # mask_pred is already processed to have the same shape as original input
874
+ image_size = mask_pred.shape[-2:]
875
+
876
+ # [Q, K]
877
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
878
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
879
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
880
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
881
+
882
+ labels_per_image = labels[topk_indices]
883
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
884
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
885
+ mask_pred = mask_pred[topk_indices]
886
+ if box_pred is not None:
887
+ box_pred = box_pred[topk_indices]
888
+
889
+ # if this is panoptic segmentation, we only keep the "thing" classes
890
+ if self.panoptic_on:
891
+ keep = torch.zeros_like(scores_per_image).bool()
892
+ for i, lab in enumerate(labels_per_image):
893
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
894
+
895
+ scores_per_image = scores_per_image[keep]
896
+ labels_per_image = labels_per_image[keep]
897
+ mask_pred = mask_pred[keep]
898
+
899
+ if box_pred is not None:
900
+ box_pred = box_pred[keep]
901
+
902
+ result = Instances(image_size)
903
+ # mask (before sigmoid)
904
+ result.pred_masks = (mask_pred > 0).float()
905
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
906
+ # Uncomment the following to get boxes from masks (this is slow)
907
+
908
+ if box_pred is not None:
909
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
910
+ else:
911
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
912
+
913
+ # calculate average mask prob
914
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
915
+ result.scores = scores_per_image * mask_scores_per_image
916
+ result.pred_classes = labels_per_image
917
+
918
+ return result
919
+
920
+
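instance_inference above selects detections by flattening the (num_queries x num_classes) score matrix and taking a global top-k; dividing the flat index by the number of classes recovers the query, while the precomputed label grid recovers the class. A minimal sketch with illustrative sizes:

    import torch

    Q, K, topk = 6, 4, 5                     # queries, classes, detections to keep (illustrative)
    scores = torch.rand(Q, K)                # per-query class probabilities (no-object dropped)
    labels = torch.arange(K).unsqueeze(0).repeat(Q, 1).flatten(0, 1)

    scores_per_det, flat_idx = scores.flatten(0, 1).topk(topk, sorted=False)
    classes_per_det = labels[flat_idx]       # equivalent to flat_idx % K
    query_per_det = flat_idx // K            # which query each detection came from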
921
+ @register_model
922
+ def get_seem_model(cfg, **kwargs):
923
+ return GeneralizedSEEM(cfg)
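The @register_model decorator and the get_seem_model factory are what let the entry points build this architecture by name from a config. The sketch below shows the usual shape of such a registry; the names here are hypothetical, and the actual implementation lives in modeling/architectures/build.py in this commit.

    # Hypothetical registry sketch, not the code from build.py.
    _MODEL_REGISTRY = {}

    def register_model(fn):
        _MODEL_REGISTRY[fn.__name__] = fn    # e.g. 'get_seem_model' -> factory function
        return fn

    def build_model_by_name(name, cfg, **kwargs):
        if name not in _MODEL_REGISTRY:
            raise KeyError(f"unknown model factory: {name}")
        return _MODEL_REGISTRY[name](cfg, **kwargs)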
modeling/architectures/seem_model_v0.py ADDED
@@ -0,0 +1,1160 @@
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All at Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import random
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from kornia.contrib import distance_transform
15
+
16
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
17
+ from detectron2.utils.memory import retry_if_cuda_oom
18
+ from detectron2.data import MetadataCatalog
19
+
20
+ from .build import register_model
21
+
22
+ from ..utils import configurable, get_class_names, get_iou
23
+ from ..vision.backbone import build_backbone, Backbone
24
+ from ..body import build_xdecoder_head
25
+ from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
26
+ from ..language import build_language_encoder
27
+ from ..language.loss import vl_similarity
28
+ from utilities.prompt_engineering import prompt_engineering
29
+ from utilities.constants import COCO_PANOPTIC_CLASSES
30
+
31
+
32
+ class GeneralizedSEEM(nn.Module):
33
+
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ *,
38
+ backbone: Backbone,
39
+ sem_seg_head: nn.Module,
40
+ criterion: nn.Module,
41
+ losses: dict,
42
+ num_queries: int,
43
+ object_mask_threshold: float,
44
+ overlap_threshold: float,
45
+ metadata,
46
+ task_switch: dict,
47
+ phrase_prob: float,
48
+ size_divisibility: int,
49
+ sem_seg_postprocess_before_inference: bool,
50
+ pixel_mean: Tuple[float],
51
+ pixel_std: Tuple[float],
52
+ # inference
53
+ semantic_on: bool,
54
+ panoptic_on: bool,
55
+ instance_on: bool,
56
+ test_topk_per_image: int,
57
+ train_dataset_name: str,
58
+ interactive_mode: str,
59
+ interactive_iter: int,
60
+ dilation_kernel: torch.Tensor,
61
+ train_max_iter: int,
62
+ ):
63
+ """
64
+ Args:
65
+ backbone: a backbone module, must follow detectron2's backbone interface
66
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
67
+ criterion: a module that defines the loss
68
+ num_queries: int, number of queries
69
+ object_mask_threshold: float, threshold to filter query based on classification score
70
+ for panoptic segmentation inference
71
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
72
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
73
+ segmentation inference
74
+ size_divisibility: Some backbones require the input height and width to be divisible by a
75
+ specific integer. We can use this to override such requirement.
76
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
77
+ to original input size before semantic segmentation inference or after.
78
+ For high-resolution dataset like Mapillary, resizing predictions before
79
+ inference will cause OOM error.
80
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
81
+ the per-channel mean and std to be used to normalize the input image
82
+ semantic_on: bool, whether to output semantic segmentation prediction
83
+ instance_on: bool, whether to output instance segmentation prediction
84
+ panoptic_on: bool, whether to output panoptic segmentation prediction
85
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
86
+ """
87
+ super().__init__()
88
+ self.backbone = backbone
89
+ self.sem_seg_head = sem_seg_head
90
+ self.criterion = criterion
91
+ self.losses = losses
92
+ self.num_queries = num_queries
93
+ self.overlap_threshold = overlap_threshold
94
+ self.object_mask_threshold = object_mask_threshold
95
+ self.metadata = metadata
96
+ if size_divisibility < 0:
97
+ # use backbone size_divisibility if not set
98
+ size_divisibility = self.backbone.size_divisibility
99
+ self.size_divisibility = size_divisibility
100
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
101
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
102
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
103
+
104
+ # additional args
105
+ self.semantic_on = semantic_on
106
+ self.instance_on = instance_on
107
+ self.panoptic_on = panoptic_on
108
+
109
+ # caption argument
110
+ self.task_switch = task_switch
111
+ self.phrase_prob = phrase_prob
112
+ self.train_max_iter = train_max_iter
113
+
114
+ self.test_topk_per_image = test_topk_per_image
115
+ self.train_class_names = get_class_names(train_dataset_name)
116
+ self.interactive_mode = interactive_mode
117
+ self.interactive_iter = interactive_iter
118
+
119
+ if not self.semantic_on:
120
+ assert self.sem_seg_postprocess_before_inference
121
+
122
+ self.register_buffer("dilation_kernel", dilation_kernel)
123
+
124
+ @classmethod
125
+ def from_config(cls, cfg):
126
+ enc_cfg = cfg['MODEL']['ENCODER']
127
+ dec_cfg = cfg['MODEL']['DECODER']
128
+
129
+ # Loss parameters:
130
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
131
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
132
+
133
+ # loss weights
134
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
135
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
136
+ 'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
137
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
138
+ 'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
139
+
140
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
141
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
142
+
143
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
144
+ 'mask': dec_cfg['MASK'].get('ENABLED', True),
145
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
146
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
147
+ 'openimage': openimage_switch}
148
+
149
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
150
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
151
+ 'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
152
+ 'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
153
+
154
+ spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
155
+ "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
156
+ "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
157
+
158
+ extra = {'task_switch': task_switch}
159
+ backbone = build_backbone(cfg)
160
+ lang_encoder = build_language_encoder(cfg)
161
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
162
+
163
+ # building criterion
164
+ matcher = HungarianMatcher(
165
+ cost_class=loss_weights['mask']['ce'],
166
+ cost_mask=loss_weights['mask']['bce'],
167
+ cost_dice=loss_weights['mask']['dice'],
168
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
169
+ spatial_cost=spatial_cost,
170
+ )
171
+
172
+ # init weight dict and criterion loss functions.
173
+ losses = {'seg': [], 'openimage': []}
174
+ if task_switch['mask']:
175
+ losses['seg'] += ["labels", "masks"]
176
+ if task_switch['spatial']:
177
+ losses['seg'] += ["spatials"]
178
+ if task_switch['grounding']:
179
+ losses['seg'] += ["groundings"]
180
+ if task_switch['openimage']:
181
+ losses['openimage'] += ["labels_openimage", "masks"]
182
+ if task_switch['openimage']['grounding']:
183
+ losses['openimage'] += ["groundings"]
184
+
185
+ weight_dict = {}
186
+ for key, turn_on in task_switch.items():
187
+ if turn_on:
188
+ if isinstance(loss_weights[key], dict):
189
+ # HACK it should support bbox in the future
190
+ for key_, weight in loss_weights[key].items():
191
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
192
+ else:
193
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
194
+
195
+ # generate full weight dict and remove not computed layers.
196
+ if deep_supervision:
197
+ dec_layers = dec_cfg['DEC_LAYERS']
198
+ aux_weight_dict = {}
199
+ for i in range(dec_layers - 1):
200
+ for k, v in weight_dict.items():
201
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
202
+ continue
203
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
204
+ weight_dict.update(aux_weight_dict)
205
+
206
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
207
+ # generate criterion for the loss function.
208
+ criterion = SetCriterion(
209
+ sem_seg_head.num_classes,
210
+ matcher=matcher,
211
+ weight_dict=weight_dict,
212
+ top_x_layers=top_x_layers,
213
+ eos_coef=no_object_weight,
214
+ losses=[],
215
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
216
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
217
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
218
+ grounding_weight=grd_weight,
219
+ )
220
+
221
+ # extra logic
222
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
223
+ train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
224
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
225
+ interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
226
+ interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
227
+
228
+ dilation = 3
229
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
230
+
231
+ return {
232
+ "backbone": backbone,
233
+ "sem_seg_head": sem_seg_head,
234
+ "criterion": criterion,
235
+ "losses": losses,
236
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
237
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
238
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
239
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
240
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
241
+ "sem_seg_postprocess_before_inference": (
242
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
243
+ or dec_cfg['TEST']['PANOPTIC_ON']
244
+ or dec_cfg['TEST']['INSTANCE_ON']
245
+ ),
246
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
247
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
248
+ "task_switch": task_switch,
249
+ "phrase_prob": phrase_prob,
250
+ # inference
251
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
252
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
253
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
254
+ "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
255
+ "train_dataset_name": train_dataset_name,
256
+ "interactive_mode": interactive_mode,
257
+ "interactive_iter": interactive_iter,
258
+ "dilation_kernel": dilation_kernel,
259
+ "train_max_iter": train_max_iter,
260
+ }
261
+
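With deep supervision enabled, from_config above replicates every layer-0 loss weight once per auxiliary decoder layer, skipping layers beyond the task's TOP_*_LAYERS limit. A small standalone sketch of that expansion (the weights and layer counts are illustrative, not the configured values):

    weight_dict = {"loss_mask_ce_0": 2.0, "loss_mask_dice_0": 5.0}   # illustrative weights
    top_x_layers = {"mask": 10}
    dec_layers = 3

    aux_weight_dict = {}
    for i in range(dec_layers - 1):
        for k, v in weight_dict.items():
            task = k.split("_")[1]                      # "mask"
            if (i + 1) > (top_x_layers[task] - 1):
                continue
            aux_weight_dict[k.replace("_0", f"_{i + 1}")] = v
    weight_dict.update(aux_weight_dict)
    # weight_dict now also holds loss_mask_ce_1/2 and loss_mask_dice_1/2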
262
+ @property
263
+ def device(self):
264
+ return self.pixel_mean.device
265
+
266
+ def forward(self, batched_inputs, mode='default'):
267
+ """
268
+ Args:
269
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
270
+ Each item in the list contains the inputs for one image.
271
+ For now, each item in the list is a dict that contains:
272
+ * "image": Tensor, image in (C, H, W) format.
273
+ * "instances": per-region ground truth
274
+ * Other information that's included in the original dicts, such as:
275
+ "height", "width" (int): the output resolution of the model (may be different
276
+ from input resolution), used in inference.
277
+ Returns:
278
+ list[dict]:
279
+ each dict has the results for one image. The dict contains the following keys:
280
+
281
+ * "sem_seg":
282
+ A Tensor that represents the
283
+ per-pixel segmentation predicted by the head.
284
+ The prediction has shape KxHxW that represents the logits of
285
+ each class for each pixel.
286
+ * "panoptic_seg":
287
+ A tuple that represent panoptic output
288
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
289
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
290
+ Each dict contains keys "id", "category_id", "isthing".
291
+ """
292
+ if self.training:
293
+ losses = {}
294
+ if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
295
+ losses_seg = self.forward_seg(batched_inputs)
296
+ losses.update(losses_seg)
297
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
298
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
299
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
300
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
301
+ losses.update(losses_openimage)
302
+ for k in list(losses.keys()):
303
+ if k in self.criterion.weight_dict:
304
+ losses[k] *= self.criterion.weight_dict[k]
305
+ else: # remove this loss if not specified in `weight_dict`
306
+ losses.pop(k)
307
+ return losses
308
+ else:
309
+ if mode == 'interactive':
310
+ return self.evaluate_interactive(batched_inputs)
311
+ elif mode == 'interactive_grounding':
312
+ return self.evaluate_interactive_grounding(batched_inputs)
313
+ elif mode == 'grounding_spatial':
314
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
315
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
316
+ return self.evaluate_grounding(batched_inputs, mode)
317
+ else:
318
+ return self.evaluate(batched_inputs)
319
+
320
+
321
+ def forward_seg(self, batched_inputs):
322
+ images = [x["image"].to(self.device) for x in batched_inputs]
323
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
324
+ images = ImageList.from_tensors(images, self.size_divisibility)
325
+
326
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
327
+
328
+ extra = {}
329
+ # mask classification target
330
+ if "instances" in batched_inputs[0]:
331
+ # input bounding box is checked to be correct.
332
+ targets = self.prepare_targets(batched_inputs, images)
333
+
334
+ if self.task_switch['grounding']:
335
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
336
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
337
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
338
+ grounding_tokens[non_zero_query_mask] = 0
339
+
340
+ extra['grounding_tokens'] = grounding_tokens
341
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
342
+
343
+ if self.task_switch['spatial']:
344
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
345
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
346
+ fp_masks = torch.stack([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs])
347
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
348
+
349
+ features = self.backbone(images.tensor)
350
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
351
+
352
+ # forward spatial only without gradient
353
+ if self.task_switch['spatial']:
354
+ with torch.no_grad():
355
+ # sample a random integer in [0, train_max_iter]
356
+ rand_iter_num = random.randint(0, self.train_max_iter)
357
+ for i in range(rand_iter_num):
358
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
359
+ extra.update(outputs)
360
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
361
+
362
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
363
+
364
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
365
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
366
+ 'false_positive_mask': extra['false_positive_mask']}
367
+ # bipartite matching-based loss
368
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
369
+ losses = self.criterion(outputs, targets, extra)
370
+
371
+ del outputs
372
+ return losses
373
+
374
+ def evaluate(self, batched_inputs):
375
+ images = [x["image"].to(self.device) for x in batched_inputs]
376
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
377
+
378
+ images = ImageList.from_tensors(images, self.size_divisibility)
379
+ img_bs = images.tensor.shape[0]
380
+
381
+ targets = targets_grounding = queries_grounding = None
382
+ features = self.backbone(images.tensor)
383
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
384
+
385
+ mask_cls_results = outputs["pred_logits"]
386
+ mask_pred_results = outputs["pred_masks"]
387
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
388
+
389
+ # upsample masks
390
+ mask_pred_results = F.interpolate(
391
+ mask_pred_results,
392
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
393
+ mode="bilinear",
394
+ align_corners=False,
395
+ )
396
+
397
+ input_size = mask_pred_results.shape[-2:]
398
+ del outputs
399
+
400
+ processed_results = []
401
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
402
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
403
+ ):
404
+ height = input_per_image.get("height", image_size[0])
405
+ width = input_per_image.get("width", image_size[1])
406
+ processed_results.append({})
407
+
408
+ if self.sem_seg_postprocess_before_inference:
409
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
410
+ mask_pred_result, image_size, height, width
411
+ )
412
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
413
+
414
+ # semantic segmentation inference
415
+ if self.semantic_on:
416
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
417
+ if not self.sem_seg_postprocess_before_inference:
418
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
419
+ processed_results[-1]["sem_seg"] = r
420
+
421
+ # panoptic segmentation inference
422
+ if self.panoptic_on:
423
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
424
+ processed_results[-1]["panoptic_seg"] = panoptic_r
425
+
426
+ # instance segmentation inference
427
+ if self.instance_on:
428
+ if self.task_switch['bbox']:
429
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
430
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
431
+ processed_results[-1]["instances"] = instance_r
432
+
433
+ return processed_results
434
+
435
+ def evaluate_interactive(self, batched_inputs):
436
+ assert self.task_switch['spatial']
437
+ assert 'spatial_query' in batched_inputs[0]
438
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
439
+
440
+ images = [x["image"].to(self.device) for x in batched_inputs]
441
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
442
+ images = ImageList.from_tensors(images, self.size_divisibility)
443
+ img_bs = images.tensor.shape[0]
444
+
445
+ targets = targets_grounding = queries_grounding = None
446
+ extra = {}
447
+
448
+ features = self.backbone(images.tensor)
449
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
450
+
451
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
452
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
453
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
454
+ mask_features = mask_features.repeat(nm,1,1,1)
455
+
456
+ all_batch_shape_iou = []
457
+ pred_smask_pointer = None
458
+ prev_smask_pointer = None
459
+ pred_smask_all = None
460
+
461
+ # visualization code
462
+ # v_pred_mask = []
463
+ # v_pos_mask = []
464
+ # v_neg_mask = []
465
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
466
+ query_index = self.sem_seg_head.predictor.query_index
467
+ if self.interactive_mode in ['best', 'best_random']:
468
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
469
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
470
+
471
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
472
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
473
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
474
+ elif self.interactive_mode == 'random':
475
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
476
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
477
+
478
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
479
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
480
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
481
+ else:
482
+ assert False, "invalid interactive mode"
483
+
484
+ for i in range(self.interactive_iter):
485
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
486
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
487
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
488
+ extra.update(outputs)
489
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
490
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
491
+
492
+ s = image_sizes[0]
493
+ b = batched_inputs[0]
494
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
495
+ gt_smask = b['gt_masks_orisize']
496
+ ious = get_iou(gt_smask, pred_smask_all)
497
+ all_batch_shape_iou += [ious]
498
+ if (ious > 0.9).sum() == len(ious):
499
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
500
+ break
501
+ if self.interactive_mode in ['best', 'best_random']:
502
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
503
+ elif self.interactive_mode == 'random':
504
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
505
+ else:
506
+ assert False, "invalid interactive mode"
507
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
508
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
509
+
510
+ return processed_results
511
+
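evaluate_interactive above stops refining once every mask reaches IoU > 0.9 against the ground truth; get_iou itself is imported from ..utils and is not part of this diff. A minimal per-mask IoU sketch, assuming boolean (N, H, W) tensors:

    import torch

    def iou_per_mask(gt: torch.Tensor, pred: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # gt, pred: bool tensors of shape (N, H, W); returns one IoU per mask.
        gt, pred = gt.flatten(1), pred.flatten(1)
        inter = (gt & pred).sum(-1).float()
        union = (gt | pred).sum(-1).float()
        return inter / (union + eps)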
512
+ def evaluate_interactive_single(self, batched_inputs, extra={}):
513
+ assert self.task_switch['spatial']
514
+ assert 'spatial_query' in batched_inputs[0]
515
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
516
+
517
+ images = [x["image"].to(self.device) for x in batched_inputs]
518
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
519
+ images = ImageList.from_tensors(images, self.size_divisibility)
520
+ img_bs = images.tensor.shape[0]
521
+
522
+ targets = targets_grounding = queries_grounding = None
523
+
524
+ features = self.backbone(images.tensor)
525
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
526
+
527
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
528
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
529
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
530
+ mask_features = mask_features.repeat(nm,1,1,1)
531
+
532
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
533
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
534
+
535
+ s = image_sizes[0]
536
+ b = batched_inputs[0]
537
+ pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
538
+ pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
539
+ ious = []
540
+ if 'gt_masks_orisize' in b:
541
+ gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
542
+ ious = get_iou(gt_smask, pred_smask_ori)
543
+ processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
544
+ return processed_results
545
+
546
+ def evaluate_interactive_grounding(self, batched_inputs):
547
+ assert self.task_switch['spatial']
548
+ assert 'spatial_query' in batched_inputs[0]
549
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
550
+
551
+ images = [x["image"].to(self.device) for x in batched_inputs]
552
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
553
+ images = ImageList.from_tensors(images, self.size_divisibility)
554
+ img_bs = images.tensor.shape[0]
555
+
556
+ targets = targets_grounding = queries_grounding = None
557
+ extra = {}
558
+
559
+ features = self.backbone(images.tensor)
560
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
561
+
562
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
563
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
564
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
565
+ mask_features = mask_features.repeat(nm,1,1,1)
566
+
567
+ all_batch_shape_iou = []
568
+ pred_smask_pointer = None
569
+ prev_smask_pointer = None
570
+ pred_smask_all = None
571
+
572
+ # visualization code
573
+ # v_pred_mask = []
574
+ # v_pos_mask = []
575
+ # v_neg_mask = []
576
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
577
+ query_index = self.sem_seg_head.predictor.query_index
578
+ if self.interactive_mode in ['best', 'best_random']:
579
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
580
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
581
+
582
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
583
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
584
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
585
+ elif self.interactive_mode == 'random':
586
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
587
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
588
+
589
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
590
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
591
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
592
+ else:
593
+ assert False, "invalid interactive mode"
594
+
595
+ grd_texts = batched_inputs[0]['classes']
596
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
597
+ token_emb = gtext['token_emb']
598
+ tokens = gtext['tokens']
599
+ query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
600
+ non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
601
+
602
+ extra['grounding_tokens'] = query_emb
603
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
604
+
605
+ for i in range(self.interactive_iter):
606
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
607
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
608
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
609
+ extra.update(outputs)
610
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
611
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
612
+
613
+ s = image_sizes[0]
614
+ b = batched_inputs[0]
615
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
616
+ gt_smask = b['gt_masks_orisize']
617
+ ious = get_iou(gt_smask, pred_smask_all)
618
+ all_batch_shape_iou += [ious]
619
+ if (ious > 0.9).sum() == len(ious):
620
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
621
+ break
622
+ if self.interactive_mode in ['best', 'best_random']:
623
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
624
+ elif self.interactive_mode == 'random':
625
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
626
+ else:
627
+ assert False, "invalid interactive mode"
628
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
629
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
630
+
631
+ # visualization
632
+ # VL.step()
633
+ # import cv2
634
+ # v_masks = []
635
+ # v_pos_masks = []
636
+ # v_neg_masks = []
637
+ # txt = []
638
+
639
+ # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
640
+ # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
641
+ # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
642
+ # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
643
+ # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
644
+ # # dilate x,y
645
+ # x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
646
+ # y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
647
+ # acc_pos_mask += x
648
+ # acc_neg_mask += y
649
+
650
+ # v_masks += [z]
651
+ # v_pos_masks += [acc_pos_mask.clip(0,1)]
652
+ # v_neg_masks += [acc_neg_mask.clip(0,1)]
653
+ # txt += ["pred_{}".format(str(iou[0].item())[0:5])]
654
+
655
+ # VL.add_image(img[:,:,::-1])
656
+ # VL.insert(mask_img, "gt_mask")
657
+ # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
658
+ return processed_results
659
+
660
+ def evaluate_referring_image(self, batched_inputs, extra={}):
661
+ assert self.task_switch['spatial']
662
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
663
+ assert self.interactive_mode == 'best'
664
+
665
+ images = [x["image"].to(self.device) for x in batched_inputs]
666
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
667
+ images = ImageList.from_tensors(images, self.size_divisibility)
668
+ img_bs = images.tensor.shape[0]
669
+
670
+ targets = targets_grounding = queries_grounding = None
671
+ features = self.backbone(images.tensor)
672
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
673
+
674
+ if 'spatial_query' in batched_inputs[0]:
675
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
676
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
677
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
678
+ mask_features = mask_features.repeat(nm,1,1,1)
679
+
680
+ query_index = self.sem_seg_head.predictor.query_index
681
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
682
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
683
+
684
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
685
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
686
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
687
+
688
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
689
+ return outputs, images.tensor.shape
690
+
691
+ def evaluate_grounding(self, batched_inputs, mode):
692
+ images = [x["image"].to(self.device) for x in batched_inputs]
693
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
694
+ images = ImageList.from_tensors(images, self.size_divisibility)
695
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
696
+
697
+ extra = {}
698
+ # mask_pred_results = []
699
+ # for idx, batch_per_image in enumerate(batched_inputs):
700
+ # grd_texts = batch_per_image['groundings']['texts']
701
+ # grd_masks = []
702
+ # for anno_text in grd_texts:
703
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
704
+ # token_emb = gtext['token_emb']
705
+ # tokens = gtext['tokens']
706
+
707
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
708
+ # extra['grounding_tokens'] = grd_emb[:,None]
709
+
710
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
711
+ # features = self.backbone(images.tensor)
712
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
713
+
714
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
715
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
716
+ # t_emb = grd_emb[-1:]
717
+
718
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
719
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
720
+
721
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
722
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
723
+
724
+ # matched_id = out_prob.max(0)[1]
725
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
726
+ # mask_pred_results += [torch.cat(grd_masks)]
727
+
728
+ # comment for multi object inference.
729
+ mask_pred_results = []
730
+ for idx, batch_per_image in enumerate(batched_inputs):
731
+ grd_texts = batch_per_image['groundings']['texts']
732
+ grd_texts = [x[0] for x in grd_texts]
733
+
734
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
735
+ token_emb = gtext['token_emb']
736
+ tokens = gtext['tokens']
737
+ query_emb = token_emb[tokens['attention_mask'].bool()]
738
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
739
+
740
+ extra['grounding_tokens'] = query_emb[:,None]
741
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
742
+
743
+ features = self.backbone(images.tensor)
744
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
745
+
746
+ pred_gmasks = outputs['pred_gmasks'][idx]
747
+ v_emb = outputs['pred_gtexts'][idx]
748
+ t_emb = gtext['class_emb']
749
+
750
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
751
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
752
+
753
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
754
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
755
+
756
+ matched_id = out_prob.max(0)[1]
757
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
758
+
759
+ for i in range(len(mask_pred_results)):
760
+ # upsample masks
761
+ mask_pred_results[i] = F.interpolate(
762
+ mask_pred_results[i][None,],
763
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
764
+ mode="bilinear",
765
+ align_corners=False,
766
+ )[0]
767
+
768
+ processed_results = []
769
+ for mask_pred_result, input_per_image, image_size in zip(
770
+ mask_pred_results, batched_inputs, images.image_sizes
771
+ ):
772
+ height = input_per_image.get("height", image_size[0])
773
+ width = input_per_image.get("width", image_size[1])
774
+ processed_results.append({})
775
+
776
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
777
+ mask_pred_result, image_size, height, width
778
+ )
779
+ processed_results[-1]['grounding_mask'] = mask_pred_result
780
+
781
+ # compute bbox
782
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
783
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
784
+ # processed_results[-1]['grounding_box'] = bbox
785
+
786
+ return processed_results
787
+
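In evaluate_grounding above, each referring phrase is matched to the grounding query whose visual embedding is most similar to the phrase's class embedding, and that query's mask is returned. vl_similarity is defined in modeling/language/loss.py and is not shown here; the sketch below is an assumption of the matching step as a temperature-scaled cosine similarity, with an illustrative temperature value.

    import torch

    def match_queries_to_texts(v_emb, t_emb, temperature=100.0):
        # v_emb: (Q, D) visual embeddings of grounding queries
        # t_emb: (T, D) text embeddings, one per phrase
        v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
        t = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
        logits = temperature * v @ t.t()         # (Q, T) scaled cosine similarity
        return logits.max(dim=0).indices         # best query index for each phrase

    # e.g. matched_id = match_queries_to_texts(v_emb, t_emb); pred_gmasks[matched_id]
    # then yields one mask per phrase.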
788
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
789
+ images = [x["image"].to(self.device) for x in batched_inputs]
790
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
791
+ images = ImageList.from_tensors(images, self.size_divisibility)
792
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
793
+
794
+ extra = {}
795
+ dilation = 3
796
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
797
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
798
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
799
+
800
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
801
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
802
+
803
+ mask_pred_results = []
804
+ for idx, batch_per_image in enumerate(batched_inputs):
805
+ grd_texts = batch_per_image['groundings']['texts']
806
+ grd_masks = []
807
+ for idx2, anno_text in enumerate(grd_texts):
808
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
809
+
810
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
811
+ token_emb = gtext['token_emb']
812
+ tokens = gtext['tokens']
813
+
814
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
815
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
816
+ extra['grounding_tokens'] = grd_emb[:,None]
817
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
818
+
819
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
820
+ features = self.backbone(images.tensor)
821
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
822
+
823
+ pred_gmasks = outputs['pred_gmasks'][idx]
824
+ v_emb = outputs['pred_gtexts'][idx]
825
+ t_emb = gtext['class_emb']
826
+
827
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
828
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
829
+
830
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
831
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
832
+
833
+ matched_id = out_prob.max(0)[1]
834
+ grd_masks += [pred_gmasks[matched_id,:,:]]
835
+ # grd_masks += [outputs['prev_mask'][0]]
836
+
837
+ mask_pred_results += [torch.cat(grd_masks)]
838
+
839
+ # comment for multi object inference.
840
+ # mask_pred_results = []
841
+ # for idx, batch_per_image in enumerate(batched_inputs):
842
+ # grd_texts = batch_per_image['groundings']['texts']
843
+ # grd_texts = [x[0] for x in grd_texts]
844
+
845
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
846
+ # token_emb = gtext['token_emb']
847
+ # tokens = gtext['tokens']
848
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
849
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
850
+
851
+ # extra['grounding_tokens'] = query_emb[:,None]
852
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
853
+
854
+ # features = self.backbone(images.tensor)
855
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
856
+
857
+ # pred_gmasks = outputs['pred_gmasks'][idx]
858
+ # v_emb = outputs['pred_gtexts'][idx]
859
+ # t_emb = gtext['class_emb']
860
+
861
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
862
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
863
+
864
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
865
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
866
+
867
+ # matched_id = out_prob.max(0)[1]
868
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
869
+
870
+ for i in range(len(mask_pred_results)):
871
+ # upsample masks
872
+ mask_pred_results[i] = F.interpolate(
873
+ mask_pred_results[i][None,],
874
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
875
+ mode="bilinear",
876
+ align_corners=False,
877
+ )[0]
878
+
879
+ processed_results = []
880
+ for mask_pred_result, input_per_image, image_size in zip(
881
+ mask_pred_results, batched_inputs, images.image_sizes
882
+ ):
883
+ height = input_per_image.get("height", image_size[0])
884
+ width = input_per_image.get("width", image_size[1])
885
+ processed_results.append({})
886
+
887
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
888
+ mask_pred_result, image_size, height, width
889
+ )
890
+ processed_results[-1]['grounding_mask'] = mask_pred_result
891
+
892
+ return processed_results
893
+
894
+ def prepare_targets(self, batched_inputs, images):
895
+ h_pad, w_pad = images.tensor.shape[-2:]
896
+ new_targets = []
897
+ for idx, batch_per_image in enumerate(batched_inputs):
898
+ targets_per_image = batch_per_image['instances'].to(self.device)
899
+ # pad gt
900
+ gt_masks = targets_per_image.gt_masks.tensor
901
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
902
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
903
+
904
+ gt_boxes = targets_per_image.gt_boxes.tensor
905
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
906
+ gt_boxes = gt_boxes / ratio
907
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
908
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
909
+
910
+ target_dict = {
911
+ "labels": targets_per_image.gt_classes,
912
+ "is_things": targets_per_image.is_things,
913
+ "masks": padded_masks,
914
+ "boxes": gt_boxes,
915
+ }
916
+
917
+ if self.task_switch['spatial']:
918
+ # prepare targets for spatial query
919
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
920
+
921
+ if self.task_switch['grounding']:
922
+ grd_masks = batch_per_image['groundings']['masks']
923
+ grd_texts = batch_per_image['groundings']['texts']
924
+ grd_hash = batch_per_image['groundings']['hash']
925
+ grd_task = batch_per_image['groundings']['mode']
926
+
927
+ if len(grd_masks) == 0:
928
+ padded_masks = None
929
+ else:
930
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
931
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
932
+
933
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
934
+ token_emb = gtext['token_emb']
935
+ tokens = gtext['tokens']
936
+
937
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
938
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)  # np.bool was removed in NumPy >= 1.24
939
+ selected_mask[unique_hash_id] = True
940
+
941
+ selected_token_emb = token_emb[selected_mask]
942
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
943
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
944
+
945
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
946
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
947
+ class_emb = token_emb[class_idx]
948
+
949
+ target_dict['grounding_masks'] = padded_masks
950
+ target_dict['grounding_query_embs'] = query_emb
951
+ target_dict['grounding_class_embs'] = class_emb
952
+ target_dict['grounding_hash'] = grd_hash
953
+ target_dict['grounding_task'] = grd_task
954
+
955
+ new_targets.append(target_dict)
956
+ return new_targets
957
+
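prepare_targets above also converts absolute XYXY ground-truth boxes into (cx, cy, w, h) coordinates normalized by the padded image size, which is the format the box losses expect. A tiny worked example with an illustrative padded size and box:

    import torch

    h_pad, w_pad = 800, 1216                                   # illustrative padded size
    boxes = torch.tensor([[100., 200., 300., 600.]])           # absolute XYXY
    boxes = boxes / torch.tensor([w_pad, h_pad, w_pad, h_pad]) # normalize to [0, 1]
    xc = (boxes[:, 0] + boxes[:, 2]) / 2
    yc = (boxes[:, 1] + boxes[:, 3]) / 2
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]
    print(torch.stack([xc, yc, w, h]).permute(1, 0))           # tensor([[0.1645, 0.5000, 0.1645, 0.5000]])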
958
+ def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
959
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
960
+ if self.training:
961
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
962
+ else:
963
+ gt_masks = ImageList.from_tensors(gt_masks, self.size_divisibility).tensor.transpose(0,1)
964
+
965
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
966
+ prev_masks = torch.stack(outputs['spatial_query_pos_mask']) | torch.stack(outputs['spatial_query_neg_mask'])
967
+
968
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
969
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
970
+
971
+ # compute iou between gt and pred
972
+ iou = (gt_masks & pred_masks).sum(list(range(1,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(1,len(fn.shape)))) + 1e-8)
973
+ fn_sum = fn.sum(dim=list(range(1,len(fn.shape))))
974
+ fp_sum = fp.sum(dim=list(range(1,len(fp.shape))))
975
+
976
+ is_postive = fn_sum > fp_sum
977
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
978
+ select_mask = torch.stack([fn[i] if is_postive[i] else fp[i] for i in range(len(fn))])
979
+
980
+ # conv implementation
981
+ n,_,h,w = select_mask.shape
982
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
983
+ if mode == 'best':
984
+ max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
985
+ elif mode == 'best_random':
986
+ max_xy_idx = torch.stack([torch.arange(n), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
987
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
988
+ next_mask = next_mask.view(n,-1)
989
+ next_mask[max_xy_idx] = True
990
+ next_mask = next_mask.reshape((n,1,h,w)).float()
991
+ dilation = 3
992
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2) > 0
993
+
994
+ # determine whether next mask is zero
995
+ keep = (iou < 0.925)
996
+ next_mask = next_mask & keep.view(-1,1,1,1)
997
+
998
+ pos_mask = []
999
+ neg_mask = []
1000
+ for idx, ip in enumerate(is_postive):
1001
+ if ip:
1002
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | next_mask[idx]]
1003
+ neg_mask += [outputs['spatial_query_neg_mask'][idx]]
1004
+ else:
1005
+ pos_mask += [outputs['spatial_query_pos_mask'][idx]]
1006
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | next_mask[idx]]
1007
+
1008
+ if 'false_positive_mask' in outputs:
1009
+ fp = outputs['false_positive_mask'] | fp
1010
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
1011
+
1012
+ def semantic_inference(self, mask_cls, mask_pred):
1013
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
1014
+ mask_pred = mask_pred.sigmoid()
1015
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
1016
+ return semseg
1017
+
1018
+ def panoptic_inference(self, mask_cls, mask_pred):
1019
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
1020
+ mask_pred = mask_pred.sigmoid()
1021
+
1022
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
1023
+ cur_scores = scores[keep]
1024
+ cur_classes = labels[keep]
1025
+ cur_masks = mask_pred[keep]
1026
+ cur_mask_cls = mask_cls[keep]
1027
+ cur_mask_cls = cur_mask_cls[:, :-1]
1028
+
1029
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
1030
+
1031
+ h, w = cur_masks.shape[-2:]
1032
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
1033
+ segments_info = []
1034
+
1035
+ current_segment_id = 0
1036
+
1037
+ if cur_masks.shape[0] == 0:
1038
+ # We didn't detect any mask :(
1039
+ return panoptic_seg, segments_info
1040
+ else:
1041
+ # take argmax
1042
+ cur_mask_ids = cur_prob_masks.argmax(0)
1043
+ stuff_memory_list = {}
1044
+ for k in range(cur_classes.shape[0]):
1045
+ pred_class = cur_classes[k].item()
1046
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
1047
+ mask_area = (cur_mask_ids == k).sum().item()
1048
+ original_area = (cur_masks[k] >= 0.5).sum().item()
1049
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
1050
+
1051
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
1052
+ if mask_area / original_area < self.overlap_threshold:
1053
+ continue
1054
+
1055
+ # merge stuff regions
1056
+ if not isthing:
1057
+ if int(pred_class) in stuff_memory_list.keys():
1058
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
1059
+ continue
1060
+ else:
1061
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
1062
+
1063
+ current_segment_id += 1
1064
+ panoptic_seg[mask] = current_segment_id
1065
+
1066
+ segments_info.append(
1067
+ {
1068
+ "id": current_segment_id,
1069
+ "isthing": bool(isthing),
1070
+ "category_id": int(pred_class),
1071
+ }
1072
+ )
1073
+
1074
+ return panoptic_seg, segments_info
1075
+
1076
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
1077
+ # mask_pred is already processed to have the same shape as original input
1078
+ image_size = mask_pred.shape[-2:]
1079
+
1080
+ # [Q, K]
1081
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
1082
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
1083
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
1084
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
1085
+
1086
+ labels_per_image = labels[topk_indices]
1087
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
1088
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
1089
+ mask_pred = mask_pred[topk_indices]
1090
+ if box_pred is not None:
1091
+ box_pred = box_pred[topk_indices]
1092
+
1093
+ # if this is panoptic segmentation, we only keep the "thing" classes
1094
+ if self.panoptic_on:
1095
+ keep = torch.zeros_like(scores_per_image).bool()
1096
+ for i, lab in enumerate(labels_per_image):
1097
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
1098
+
1099
+ scores_per_image = scores_per_image[keep]
1100
+ labels_per_image = labels_per_image[keep]
1101
+ mask_pred = mask_pred[keep]
1102
+
1103
+ if box_pred is not None:
1104
+ box_pred = box_pred[keep]
1105
+
1106
+ result = Instances(image_size)
1107
+ # mask (before sigmoid)
1108
+ result.pred_masks = (mask_pred > 0).float()
1109
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
1110
+ # Uncomment the following to get boxes from masks (this is slow)
1111
+
1112
+ if box_pred is not None:
1113
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
1114
+ else:
1115
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
1116
+
1117
+ # calculate average mask prob
1118
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
1119
+ result.scores = scores_per_image * mask_scores_per_image
1120
+ result.pred_classes = labels_per_image
1121
+
1122
+ return result
1123
+
1124
+ def prepare_targets4query(self, targets, images, topk=5):
1125
+ h_pad, w_pad = images.tensor.shape[-2:]
1126
+ new_targets = []
1127
+ new_queries = []
1128
+ for targets_per_image in targets:
1129
+ # we randomly sample at most topk concepts
1130
+ unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
1131
+ selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
1132
+ new_targets_per_image = []
1133
+ new_queries_per_image = []
1134
+ for clss in selected_target_classes:
1135
+ indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
1136
+ # pad gt
1137
+ gt_masks = targets_per_image.gt_masks[indices]
1138
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
1139
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
1140
+
1141
+ # convert class into concept name and then token seq
1142
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([COCO_PANOPTIC_CLASSES[clss]], name='grounding')
1143
+ query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
1144
+
1145
+ new_targets.append(
1146
+ {
1147
+ "labels": targets_per_image.gt_classes[indices],
1148
+ "masks": padded_masks,
1149
+ }
1150
+ )
1151
+ new_queries_per_image.append(query)
1152
+ new_queries.append(new_queries_per_image)
1153
+
1154
+ return new_targets, new_queries
1155
+
1156
+
1157
+
1158
+ @register_model
1159
+ def get_seem_model(cfg, **kwargs):
1160
+ return GeneralizedSEEM(cfg)
modeling/architectures/seem_model_v1.py ADDED
@@ -0,0 +1,1179 @@
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All at Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import random
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from kornia.contrib import distance_transform
15
+
16
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks
17
+ from detectron2.utils.memory import retry_if_cuda_oom
18
+ from detectron2.data import MetadataCatalog
19
+
20
+ from .build import register_model
21
+
22
+ from ..utils import configurable, get_class_names, get_iou, Spatial_ImageList
23
+ from ..vision.backbone import build_backbone, Backbone
24
+ from ..body import build_xdecoder_head
25
+ from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
26
+ from ..language import build_language_encoder
27
+ from ..language.loss import vl_similarity
28
+ from utilities.prompt_engineering import prompt_engineering
29
+ from utilities.constants import COCO_PANOPTIC_CLASSES, BIOMED_CLASSES
30
+
31
+
32
+ class GeneralizedSEEM(nn.Module):
33
+
34
+ @configurable
35
+ def __init__(
36
+ self,
37
+ *,
38
+ backbone: Backbone,
39
+ sem_seg_head: nn.Module,
40
+ criterion: nn.Module,
41
+ losses: dict,
42
+ num_queries: int,
43
+ object_mask_threshold: float,
44
+ overlap_threshold: float,
45
+ metadata,
46
+ task_switch: dict,
47
+ phrase_prob: float,
48
+ size_divisibility: int,
49
+ sem_seg_postprocess_before_inference: bool,
50
+ pixel_mean: Tuple[float],
51
+ pixel_std: Tuple[float],
52
+ # inference
53
+ semantic_on: bool,
54
+ panoptic_on: bool,
55
+ instance_on: bool,
56
+ test_topk_per_image: int,
57
+ train_dataset_name: str,
58
+ interactive_mode: str,
59
+ interactive_iter: str,
60
+ dilation_kernel: torch.Tensor,
61
+ train_max_iter: int,
62
+ binary_classes: bool,
63
+ standard_text_for_eval: bool,
64
+ ):
65
+ """
66
+ Args:
67
+ backbone: a backbone module, must follow detectron2's backbone interface
68
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
69
+ criterion: a module that defines the loss
70
+ num_queries: int, number of queries
71
+ object_mask_threshold: float, threshold to filter query based on classification score
72
+ for panoptic segmentation inference
73
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
74
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
75
+ segmentation inference
76
+ size_divisibility: Some backbones require the input height and width to be divisible by a
77
+ specific integer. We can use this to override such a requirement.
78
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
79
+ to original input size before semantic segmentation inference or after.
80
+ For high-resolution datasets like Mapillary, resizing predictions before
81
+ inference will cause OOM error.
82
+ pixel_mean, pixel_std: list or tuple with #channels elements, representing
83
+ the per-channel mean and std to be used to normalize the input image
84
+ semantic_on: bool, whether to output semantic segmentation prediction
85
+ instance_on: bool, whether to output instance segmentation prediction
86
+ panoptic_on: bool, whether to output panoptic segmentation prediction
87
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
88
+ """
89
+ super().__init__()
90
+ self.backbone = backbone
91
+ self.sem_seg_head = sem_seg_head
92
+ self.criterion = criterion
93
+ self.losses = losses
94
+ self.num_queries = num_queries
95
+ self.overlap_threshold = overlap_threshold
96
+ self.object_mask_threshold = object_mask_threshold
97
+ self.metadata = metadata
98
+ if size_divisibility < 0:
99
+ # use backbone size_divisibility if not set
100
+ size_divisibility = self.backbone.size_divisibility
101
+ self.size_divisibility = size_divisibility
102
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
103
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
104
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
105
+
106
+ # additional args
107
+ self.semantic_on = semantic_on
108
+ self.instance_on = instance_on
109
+ self.panoptic_on = panoptic_on
110
+
111
+ # caption argument
112
+ self.task_switch = task_switch
113
+ self.phrase_prob = phrase_prob
114
+ self.train_max_iter = train_max_iter
115
+
116
+ self.test_topk_per_image = test_topk_per_image
117
+ self.train_class_names = get_class_names(train_dataset_name)
118
+ if binary_classes:
119
+ self.train_class_names = ['target', 'background']
120
+ self.interactive_mode = interactive_mode
121
+ self.interactive_iter = interactive_iter
122
+
123
+ if not self.semantic_on:
124
+ assert self.sem_seg_postprocess_before_inference
125
+
126
+ self.register_buffer("dilation_kernel", dilation_kernel)
127
+
128
+ self.standard_text_for_eval = standard_text_for_eval
129
+
130
+ @classmethod
131
+ def from_config(cls, cfg):
132
+ enc_cfg = cfg['MODEL']['ENCODER']
133
+ dec_cfg = cfg['MODEL']['DECODER']
134
+
135
+ # Loss parameters:
136
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
137
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
138
+
139
+ # loss weights
140
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
141
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
142
+ 'spatial': {'ce': dec_cfg['SCLASS_WEIGHT'], 'dice': dec_cfg['SDICE_WEIGHT'], 'bce': dec_cfg['SMASK_WEIGHT']},
143
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']},
144
+ 'openimage': {'ce': dec_cfg['OCLASS_WEIGHT'], 'dice': dec_cfg['ODICE_WEIGHT'], 'bce': dec_cfg['OMASK_WEIGHT']}}
145
+
146
+ openimage_switch = {'grounding': dec_cfg['OPENIMAGE']['GROUNDING'].get('ENABLED', False),
147
+ 'mask': dec_cfg['OPENIMAGE'].get('ENABLED', False)}
148
+
149
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
150
+ 'mask': dec_cfg['MASK'].get('ENABLED', True),
151
+ 'spatial': dec_cfg['SPATIAL'].get('ENABLED', False),
152
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False),
153
+ 'openimage': openimage_switch}
154
+
155
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
156
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),
157
+ 'openimage': dec_cfg.get('TOP_OPENIMAGE_LAYERS', 10),
158
+ 'spatial': dec_cfg.get('TOP_SPATIAL_LAYERS', 10)}
159
+
160
+ spatial_cost = {"class_weight": dec_cfg['COST_SPATIAL']['CLASS_WEIGHT'],
161
+ "mask_weight": dec_cfg['COST_SPATIAL']['MASK_WEIGHT'],
162
+ "dice_weight": dec_cfg['COST_SPATIAL']['DICE_WEIGHT']}
163
+
164
+ extra = {'task_switch': task_switch}
165
+ backbone = build_backbone(cfg)
166
+ lang_encoder = build_language_encoder(cfg)
167
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra=extra)
168
+
169
+ # building criterion
170
+ matcher = HungarianMatcher(
171
+ cost_class=loss_weights['mask']['ce'],
172
+ cost_mask=loss_weights['mask']['bce'],
173
+ cost_dice=loss_weights['mask']['dice'],
174
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
175
+ spatial_cost=spatial_cost,
176
+ )
177
+
178
+ # init weight dict and criterion loss functions.
179
+ losses = {'seg': [], 'openimage': []}
180
+ if task_switch['mask']:
181
+ losses['seg'] += ["labels", "masks"]
182
+ if task_switch['spatial']:
183
+ losses['seg'] += ["spatials"]
184
+ if task_switch['grounding']:
185
+ losses['seg'] += ["groundings"]
186
+ if task_switch['openimage']:
187
+ losses['openimage'] += ["labels_openimage", "masks"]
188
+ if task_switch['openimage']['grounding']:
189
+ losses['openimage'] += ["groundings"]
190
+
191
+ weight_dict = {}
192
+ for key, turn_on in task_switch.items():
193
+ if turn_on:
194
+ if isinstance(loss_weights[key], dict):
195
+ # HACK it should support bbox in the future
196
+ for key_, weight in loss_weights[key].items():
197
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard-coded for segmentation, which has multiple losses
198
+ else:
199
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
200
+
201
+ # generate the full weight dict and drop layers that are not computed.
202
+ if deep_supervision:
203
+ dec_layers = dec_cfg['DEC_LAYERS']
204
+ aux_weight_dict = {}
205
+ for i in range(dec_layers - 1):
206
+ for k, v in weight_dict.items():
207
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
208
+ continue
209
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
210
+ weight_dict.update(aux_weight_dict)
211
+
212
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
213
+ # generate criterion for the loss function.
214
+ criterion = SetCriterion(
215
+ sem_seg_head.num_classes,
216
+ matcher=matcher,
217
+ weight_dict=weight_dict,
218
+ top_x_layers=top_x_layers,
219
+ eos_coef=no_object_weight,
220
+ losses=[],
221
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
222
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
223
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
224
+ grounding_weight=grd_weight,
225
+ )
226
+
227
+ # extra logic
228
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
229
+ train_max_iter = dec_cfg['SPATIAL'].get('MAX_ITER', 3)
230
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
231
+ interactive_mode = cfg['STROKE_SAMPLER']['EVAL']['MODE']
232
+ interactive_iter = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
233
+
234
+ dilation = 3
235
+ dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
236
+
237
+ return {
238
+ "backbone": backbone,
239
+ "sem_seg_head": sem_seg_head,
240
+ "criterion": criterion,
241
+ "losses": losses,
242
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
243
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
244
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
245
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
246
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
247
+ "sem_seg_postprocess_before_inference": (
248
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
249
+ or dec_cfg['TEST']['PANOPTIC_ON']
250
+ or dec_cfg['TEST']['INSTANCE_ON']
251
+ ),
252
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
253
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
254
+ "task_switch": task_switch,
255
+ "phrase_prob": phrase_prob,
256
+ # inference
257
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
258
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
259
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
260
+ "test_topk_per_image": cfg['TEST']['DETECTIONS_PER_IMAGE'],
261
+ "train_dataset_name": train_dataset_name,
262
+ "interactive_mode": interactive_mode,
263
+ "interactive_iter": interactive_iter,
264
+ "dilation_kernel": dilation_kernel,
265
+ "train_max_iter": train_max_iter,
266
+ "binary_classes": enc_cfg['BINARY_CLASSES'],
267
+ "standard_text_for_eval": cfg['STANDARD_TEXT_FOR_EVAL'],
268
+ }
269
+
270
+ @property
271
+ def device(self):
272
+ return self.pixel_mean.device
273
+
274
+ def forward(self, batched_inputs, mode='default'):
275
+ """
276
+ Args:
277
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
278
+ Each item in the list contains the inputs for one image.
279
+ For now, each item in the list is a dict that contains:
280
+ * "image": Tensor, image in (C, H, W) format.
281
+ * "instances": per-region ground truth
282
+ * Other information that's included in the original dicts, such as:
283
+ "height", "width" (int): the output resolution of the model (may be different
284
+ from input resolution), used in inference.
285
+ Returns:
286
+ list[dict]:
287
+ each dict has the results for one image. The dict contains the following keys:
288
+
289
+ * "sem_seg":
290
+ A Tensor that represents the
291
+ per-pixel segmentation predicted by the head.
292
+ The prediction has shape KxHxW that represents the logits of
293
+ each class for each pixel.
294
+ * "panoptic_seg":
295
+ A tuple that represents the panoptic output
296
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
297
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
298
+ Each dict contains keys "id", "category_id", "isthing".
299
+ """
300
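+ # Minimal illustrative call at inference time (a sketch only; the image size and
+ # variable names below are assumptions, not values taken from this repository):
+ #   inputs = [{"image": torch.zeros(3, 1024, 1024), "height": 1024, "width": 1024}]
+ #   results = model(inputs)              # mode='default' dispatches to self.evaluate
+ #   sem_seg = results[0]["sem_seg"]      # (K, height, width) per-class scores, if semantic_on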
+ if self.training:
301
+ losses = {}
302
+ if self.task_switch['mask'] or self.task_switch['grounding'] or self.task_switch['spatial']:
303
+ losses_seg = self.forward_seg(batched_inputs)
304
+ losses.update(losses_seg)
305
+ if self.task_switch['openimage'] and self.task_switch['openimage']['mask']:
306
+ losses_openimage = self.forward_openimage(batched_inputs['openimage'])
307
+ losses_openimage = {key.replace('mask', 'openimage'):value for key, value in losses_openimage.items()}
308
+ losses_openimage = {key.replace('grounding', 'grounding_openimage'):value for key, value in losses_openimage.items()}
309
+ losses.update(losses_openimage)
310
+ for k in list(losses.keys()):
311
+ if k in self.criterion.weight_dict:
312
+ losses[k] *= self.criterion.weight_dict[k]
313
+ else: # remove this loss if not specified in `weight_dict`
314
+ losses.pop(k)
315
+ return losses
316
+ else:
317
+ if mode == 'interactive':
318
+ return self.evaluate_interactive(batched_inputs)
319
+ elif mode == 'interactive_grounding':
320
+ return self.evaluate_interactive_grounding(batched_inputs)
321
+ elif mode == 'grounding_spatial':
322
+ return self.evaluate_grounding_sptial(batched_inputs, mode)
323
+ elif mode in ['grounding_phrasecut', 'grounding_refcoco']:
324
+ return self.evaluate_grounding(batched_inputs, mode)
325
+ else:
326
+ return self.evaluate(batched_inputs)
327
+
328
+
329
+ def forward_seg(self, batched_inputs):
330
+ images = [x["image"].to(self.device) for x in batched_inputs]
331
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
332
+ images = ImageList.from_tensors(images, self.size_divisibility)
333
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
334
+
335
+ extra = {}
336
+ # mask classification target
337
+ if "instances" in batched_inputs[0]:
338
+ # input bounding box is checked to be correct.
339
+ targets = self.prepare_targets(batched_inputs, images)
340
+
341
+ if self.task_switch['grounding']:
342
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
343
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens, padding_value=-1)
344
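+ # pad_sequence fills shorter sequences with -1; positions whose embedding sums to -C (every channel equal to the pad value) are flagged as padding and zeroed out below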
+ non_zero_query_mask = (grounding_tokens.sum(dim=-1) == -grounding_tokens.shape[-1])
345
+ grounding_tokens[non_zero_query_mask] = 0
346
+
347
+ extra['grounding_tokens'] = grounding_tokens
348
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
349
+
350
+ if self.task_switch['spatial']:
351
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device) for x in batched_inputs]
352
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs]
353
+ fp_masks = nn.utils.rnn.pad_sequence([(x['spatial_query']['rand_shape'].to(self.device) & False) for x in batched_inputs], padding_value=False, batch_first=True)
354
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks, 'false_positive_mask': fp_masks})
355
+
356
+ features = self.backbone(images.tensor)
357
+ mask_features, _, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
358
+
359
+ # forward spatial only without gradient
360
+ if self.task_switch['spatial']:
361
+ with torch.no_grad():
362
+ # generate a random integer in [0, train_max_iter]
363
+ rand_iter_num = random.randint(0, self.train_max_iter)
364
+ for i in range(rand_iter_num):
365
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='spatial')
366
+ extra.update(outputs)
367
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs))
368
+
369
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, extra=extra, task='seg')
370
+
371
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
372
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default')),
373
+ 'false_positive_mask': extra['false_positive_mask']}
374
+ # bipartite matching-based loss
375
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
376
+
377
+ if self.task_switch['mask']:
378
+ losses = self.criterion(outputs, targets, extra)
379
+ else:
380
+ losses = self.criterion.forward_vlp(outputs, targets, extra)
381
+
382
+ del outputs
383
+ return losses
384
+
385
+ def evaluate(self, batched_inputs):
386
+ images = [x["image"].to(self.device) for x in batched_inputs]
387
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
388
+
389
+ images = ImageList.from_tensors(images, self.size_divisibility)
390
+ img_bs = images.tensor.shape[0]
391
+
392
+ targets = targets_grounding = queries_grounding = None
393
+ features = self.backbone(images.tensor)
394
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
395
+
396
+ mask_cls_results = outputs["pred_logits"]
397
+ mask_pred_results = outputs["pred_masks"]
398
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
399
+
400
+ # upsample masks
401
+ mask_pred_results = F.interpolate(
402
+ mask_pred_results,
403
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
404
+ mode="bilinear",
405
+ align_corners=False,
406
+ )
407
+
408
+ input_size = mask_pred_results.shape[-2:]
409
+ del outputs
410
+
411
+ processed_results = []
412
+ for mask_cls_result, mask_pred_result, box_pred_result, input_per_image, image_size in zip(
413
+ mask_cls_results, mask_pred_results, box_pred_results, batched_inputs, images.image_sizes
414
+ ):
415
+ height = input_per_image.get("height", image_size[0])
416
+ width = input_per_image.get("width", image_size[1])
417
+ processed_results.append({})
418
+
419
+ if self.sem_seg_postprocess_before_inference:
420
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
421
+ mask_pred_result, image_size, height, width
422
+ )
423
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
424
+
425
+ # semantic segmentation inference
426
+ if self.semantic_on:
427
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
428
+ if not self.sem_seg_postprocess_before_inference:
429
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
430
+ processed_results[-1]["sem_seg"] = r
431
+
432
+ # panoptic segmentation inference
433
+ if self.panoptic_on:
434
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
435
+ processed_results[-1]["panoptic_seg"] = panoptic_r
436
+
437
+ # instance segmentation inference
438
+ if self.instance_on:
439
+ if self.task_switch['bbox']:
440
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
441
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
442
+ processed_results[-1]["instances"] = instance_r
443
+
444
+ return processed_results
445
+
446
+ def evaluate_interactive(self, batched_inputs):
447
+ assert self.task_switch['spatial']
448
+ assert 'spatial_query' in batched_inputs[0]
449
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
450
+
451
+ images = [x["image"].to(self.device) for x in batched_inputs]
452
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
453
+ images = ImageList.from_tensors(images, self.size_divisibility)
454
+ img_bs = images.tensor.shape[0]
455
+
456
+ targets = targets_grounding = queries_grounding = None
457
+ extra = {}
458
+
459
+ features = self.backbone(images.tensor)
460
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
461
+
462
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
463
+
464
+ all_batch_shape_iou = []
465
+ pred_smask_pointer = None
466
+ prev_smask_pointer = None
467
+ pred_smask_all = None
468
+
469
+ # visualization code
470
+ # v_pred_mask = []
471
+ # v_pos_mask = []
472
+ # v_neg_mask = []
473
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
474
+ query_index = self.sem_seg_head.predictor.query_index
475
+ if self.interactive_mode in ['best', 'best_random']:
476
+ pos_masks = [x['spatial_query']['rand_shape'].to(self.device)[:,0] for x in batched_inputs]
477
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
478
+
479
+ neg_masks = [(x['spatial_query']['rand_shape'].to(self.device) & False)[:,0] for x in batched_inputs]
480
+
481
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
482
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
483
+ elif self.interactive_mode == 'random':
484
+ assert False, "interactive mode not correctly implemented"
485
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
486
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
487
+
488
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
489
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
490
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
491
+ else:
492
+ assert False, "invalid interactive mode"
493
+
494
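+ # each iteration simulates one more user click: run the spatial decoder, score IoU against the ground truth at the original resolution, and stop early once every mask exceeds 0.9 IoU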
+ for i in range(self.interactive_iter):
495
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
496
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
497
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
498
+ extra.update(outputs)
499
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
500
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
501
+
502
+ s = image_sizes[0]
503
+ b = batched_inputs[0]
504
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[0].sigmoid() > 0.5
505
+ gt_smask = b['gt_masks_orisize']
506
+ ious = get_iou(gt_smask, pred_smask_all)
507
+ all_batch_shape_iou += [ious]
508
+ if (ious > 0.9).sum() == len(ious):
509
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
510
+ break
511
+ if self.interactive_mode in ['best', 'best_random']:
512
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
513
+ elif self.interactive_mode == 'random':
514
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
515
+ else:
516
+ assert False, "invalid interactive mode"
517
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
518
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
519
+
520
+ return processed_results
521
+
522
+ def evaluate_interactive_single(self, batched_inputs, extra={}):
523
+ assert self.task_switch['spatial']
524
+ assert 'spatial_query' in batched_inputs[0]
525
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
526
+
527
+ images = [x["image"].to(self.device) for x in batched_inputs]
528
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
529
+ images = ImageList.from_tensors(images, self.size_divisibility)
530
+ img_bs = images.tensor.shape[0]
531
+
532
+ targets = targets_grounding = queries_grounding = None
533
+
534
+ features = self.backbone(images.tensor)
535
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
536
+
537
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
538
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
539
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
540
+ mask_features = mask_features.repeat(nm,1,1,1)
541
+
542
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
543
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bicubic')
544
+
545
+ s = image_sizes[0]
546
+ b = batched_inputs[0]
547
+ pred_smask_ori = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bicubic')[:,0].sigmoid() > 0.5
548
+ pred_smask_batch = pred_smask[:,:,:s[0],:s[1]].sigmoid() > 0.5
549
+ ious = []
550
+ if 'gt_masks_orisize' in b:
551
+ gt_smask = b['gt_masks_orisize'].to(pred_smask_ori.device)
552
+ ious = get_iou(gt_smask, pred_smask_ori)
553
+ processed_results = [{"mask_iou": ious, 'pred_mask_ori': pred_smask_ori, 'pred_mask_batch': pred_smask_batch}]
554
+ return processed_results
555
+
556
+ def evaluate_interactive_grounding(self, batched_inputs):
557
+ assert self.task_switch['spatial']
558
+ assert 'spatial_query' in batched_inputs[0]
559
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
560
+
561
+ images = [x["image"].to(self.device) for x in batched_inputs]
562
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
563
+ images = ImageList.from_tensors(images, self.size_divisibility)
564
+ img_bs = images.tensor.shape[0]
565
+
566
+ targets = targets_grounding = queries_grounding = None
567
+ extra = {}
568
+
569
+ features = self.backbone(images.tensor)
570
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
571
+
572
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
573
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
574
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
575
+ mask_features = mask_features.repeat(nm,1,1,1)
576
+
577
+ all_batch_shape_iou = []
578
+ pred_smask_pointer = None
579
+ prev_smask_pointer = None
580
+ pred_smask_all = None
581
+
582
+ # visualization code
583
+ # v_pred_mask = []
584
+ # v_pos_mask = []
585
+ # v_neg_mask = []
586
+ # v_gt_mask = batched_inputs[0]['spatial_query']['gt_masks'][0]
587
+ query_index = self.sem_seg_head.predictor.query_index
588
+ if self.interactive_mode in ['best', 'best_random']:
589
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
590
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
591
+
592
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
593
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
594
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
595
+ elif self.interactive_mode == 'random':
596
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==1).unbind(0)
597
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
598
+
599
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)==-1).unbind(0)
600
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor
601
+ extra.update({'spatial_query_pos_mask': pos_masks[:,0:1].unbind(), 'spatial_query_neg_mask': neg_masks[:,0:1].unbind()})
602
+ else:
603
+ assert False, "invalid interactive mode"
604
+
605
+ grd_texts = batched_inputs[0]['classes']
606
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
607
+ token_emb = gtext['token_emb']
608
+ tokens = gtext['tokens']
609
+ query_emb = nn.utils.rnn.pad_sequence([_token_emb[_tokens.bool()] for _token_emb, _tokens in zip(token_emb, tokens['attention_mask'])], padding_value=-1)
610
+ non_zero_query_mask = (query_emb.sum(dim=-1) < 0)
611
+
612
+ extra['grounding_tokens'] = query_emb
613
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
614
+
615
+ for i in range(self.interactive_iter):
616
+ # v_pos_mask += [extra['spatial_query_pos_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
617
+ # v_neg_mask += [extra['spatial_query_neg_mask'][0][0][:image_sizes[0][0],:image_sizes[0][1]].float().cpu().numpy()]
618
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='spatial')
619
+ extra.update(outputs)
620
+ pred_smask = F.interpolate(outputs['prev_mask'], images.tensor.shape[-2:], mode='bilinear')
621
+ # v_pred_mask += [(pred_smask[0,0][:image_sizes[0][0],:image_sizes[0][1]].sigmoid() > 0.5).float().cpu().numpy()]
622
+
623
+ s = image_sizes[0]
624
+ b = batched_inputs[0]
625
+ pred_smask_all = F.interpolate(pred_smask[:,:,:s[0],:s[1]], (b['height'], b['width']), mode='bilinear')[:,0].sigmoid() > 0.5
626
+ gt_smask = b['gt_masks_orisize']
627
+ ious = get_iou(gt_smask, pred_smask_all)
628
+ all_batch_shape_iou += [ious]
629
+ if (ious > 0.9).sum() == len(ious):
630
+ all_batch_shape_iou += [ious for j in range(self.interactive_iter-i-1)]
631
+ break
632
+ if self.interactive_mode in ['best', 'best_random']:
633
+ extra.update(self.prepare_next_spaital_mask(extra, batched_inputs, mode=self.interactive_mode))
634
+ elif self.interactive_mode == 'random':
635
+ extra.update({'spatial_query_pos_mask': pos_masks[:,i+1:i+2].unbind(), 'spatial_query_neg_mask': neg_masks[:,i+1:i+2].unbind()})
636
+ else:
637
+ assert False, "invalid interactive mode"
638
+ all_batch_shape_iou = torch.stack(all_batch_shape_iou)
639
+ processed_results = [{"mask_iou": all_batch_shape_iou[:,i]} for i in range(len(all_batch_shape_iou[0]))]
640
+
641
+ # visualization
642
+ # VL.step()
643
+ # import cv2
644
+ # v_masks = []
645
+ # v_pos_masks = []
646
+ # v_neg_masks = []
647
+ # txt = []
648
+
649
+ # img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
650
+ # mask_img = VL.overlay_single_mask_to_image(img[:,:,::-1], v_gt_mask.cpu().float().numpy())
651
+ # acc_pos_mask = np.zeros(v_pos_mask[0].shape)
652
+ # acc_neg_mask = np.zeros(v_neg_mask[0].shape)
653
+ # for x,y,z,iou in zip(v_pos_mask, v_neg_mask, v_pred_mask, all_batch_shape_iou):
654
+ # # dilate x,y
655
+ # x = cv2.dilate(x, np.ones((5,5), np.uint8), iterations=3)
656
+ # y = cv2.dilate(y, np.ones((5,5), np.uint8), iterations=3)
657
+ # acc_pos_mask += x
658
+ # acc_neg_mask += y
659
+
660
+ # v_masks += [z]
661
+ # v_pos_masks += [acc_pos_mask.clip(0,1)]
662
+ # v_neg_masks += [acc_neg_mask.clip(0,1)]
663
+ # txt += ["pred_{}".format(str(iou[0].item())[0:5])]
664
+
665
+ # VL.add_image(img[:,:,::-1])
666
+ # VL.insert(mask_img, "gt_mask")
667
+ # VL.overlay_obj_mask_to_image_withposneg(img[:,:,::-1], v_masks, v_pos_masks, v_neg_masks, txt, max_len=20)
668
+ return processed_results
669
+
670
+ def evaluate_referring_image(self, batched_inputs, extra={}):
671
+ assert self.task_switch['spatial']
672
+ assert len(batched_inputs) == 1, "only support batch size equal to 1"
673
+ assert self.interactive_mode == 'best'
674
+
675
+ images = [x["image"].to(self.device) for x in batched_inputs]
676
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
677
+ images = ImageList.from_tensors(images, self.size_divisibility)
678
+ img_bs = images.tensor.shape[0]
679
+
680
+ targets = targets_grounding = queries_grounding = None
681
+ features = self.backbone(images.tensor)
682
+ mask_features, transformer_encoder_features, multi_scale_features = self.sem_seg_head.pixel_decoder.forward_features(features)
683
+
684
+ if 'spatial_query' in batched_inputs[0]:
685
+ image_sizes = [x["image"].shape[-2:] for x in batched_inputs]
686
+ nm = len(batched_inputs[0]['spatial_query']['rand_shape'])
687
+ multi_scale_features = [m.repeat(nm,1,1,1) for m in multi_scale_features]
688
+ mask_features = mask_features.repeat(nm,1,1,1)
689
+
690
+ query_index = self.sem_seg_head.predictor.query_index
691
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
692
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor.unbind(0)
693
+
694
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
695
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
696
+ extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})
697
+
698
+ outputs = self.sem_seg_head.predictor(multi_scale_features, mask_features, target_queries=queries_grounding, extra=extra, task='refimg')
699
+ return outputs, images.tensor.shape
700
+
701
+ def evaluate_grounding(self, batched_inputs, mode):
702
+ images = [x["image"].to(self.device) for x in batched_inputs]
703
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
704
+ images = ImageList.from_tensors(images, self.size_divisibility)
705
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
706
+
707
+ extra = {}
708
+ # mask_pred_results = []
709
+ # for idx, batch_per_image in enumerate(batched_inputs):
710
+ # grd_texts = batch_per_image['groundings']['texts']
711
+ # grd_masks = []
712
+ # for anno_text in grd_texts:
713
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
714
+ # token_emb = gtext['token_emb']
715
+ # tokens = gtext['tokens']
716
+
717
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
718
+ # extra['grounding_tokens'] = grd_emb[:,None]
719
+
720
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
721
+ # features = self.backbone(images.tensor)
722
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
723
+
724
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
725
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
726
+ # t_emb = grd_emb[-1:]
727
+
728
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
729
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
730
+
731
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
732
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
733
+
734
+ # matched_id = out_prob.max(0)[1]
735
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
736
+ # mask_pred_results += [torch.cat(grd_masks)]
737
+
738
+ # comment for multi object inference.
739
+ mask_pred_results = []
740
+ for idx, batch_per_image in enumerate(batched_inputs):
741
+ grd_texts = batch_per_image['groundings']['texts']
742
+ if self.standard_text_for_eval:
743
+ standard_texts = []
744
+ for grd in batch_per_image['grounding_info']:
745
+ mask_file = grd['mask_file'].split('.')[0].split('/')[-1]
746
+ target = mask_file.split('_')[-1].replace('+', ' ')
747
+ site = mask_file.split('_')[-2].replace('+', ' ')
748
+ modality = mask_file.split('_')[-3].replace('+', ' ')
749
+ standard_texts.append(f'{target} in {site} {modality}')
750
+ grd_texts = standard_texts
751
+ batch_per_image['groundings']['texts'] = standard_texts
752
+
753
+
754
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
755
+ token_emb = gtext['token_emb']
756
+ tokens = gtext['tokens']
757
+ query_emb = token_emb[tokens['attention_mask'].bool()]
758
+ non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
759
+
760
+ extra['grounding_tokens'] = query_emb[:,None]
761
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
762
+
763
+ features = self.backbone(images.tensor)
764
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
765
+
766
+ pred_gmasks = outputs['pred_gmasks'][idx]
767
+ v_emb = outputs['pred_gtexts'][idx]
768
+ t_emb = gtext['class_emb']
769
+
770
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
771
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
772
+
773
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
774
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
775
+
776
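+ # each text embedding selects the mask proposal with the highest vision-language similarity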
+ matched_id = out_prob.max(0)[1]
777
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
778
+
779
+ for i in range(len(mask_pred_results)):
780
+ # upsample masks
781
+ mask_pred_results[i] = F.interpolate(
782
+ mask_pred_results[i][None,],
783
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
784
+ mode="bilinear",
785
+ align_corners=False,
786
+ )[0]
787
+
788
+ processed_results = []
789
+ for mask_pred_result, input_per_image, image_size in zip(
790
+ mask_pred_results, batched_inputs, images.image_sizes
791
+ ):
792
+ height = input_per_image.get("height", image_size[0])
793
+ width = input_per_image.get("width", image_size[1])
794
+ processed_results.append({})
795
+
796
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
797
+ mask_pred_result, image_size, height, width
798
+ )
799
+ processed_results[-1]['grounding_mask'] = mask_pred_result
800
+
801
+ # compute bbox
802
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
803
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
804
+ # processed_results[-1]['grounding_box'] = bbox
805
+
806
+ return processed_results
807
+
808
+ def evaluate_grounding_sptial(self, batched_inputs, mode):
809
+ images = [x["image"].to(self.device) for x in batched_inputs]
810
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
811
+ images = ImageList.from_tensors(images, self.size_divisibility)
812
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
813
+
814
+ extra = {}
815
+ dilation = 3
816
+ pos_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device)).unbind(0)
817
+ pos_masks = ImageList.from_tensors(pos_masks, self.size_divisibility).tensor
818
+ pos_masks = (F.conv2d(pos_masks.float(), self.dilation_kernel, padding=dilation//2) > 0).unbind(0)
819
+
820
+ neg_masks = (batched_inputs[0]['spatial_query']['rand_shape'].to(self.device) & False).unbind(0)
821
+ neg_masks = ImageList.from_tensors(neg_masks, self.size_divisibility).tensor.unbind(0)
822
+
823
+ mask_pred_results = []
824
+ for idx, batch_per_image in enumerate(batched_inputs):
825
+ grd_texts = batch_per_image['groundings']['texts']
826
+ grd_masks = []
827
+ for idx2, anno_text in enumerate(grd_texts):
828
+ extra.update({'spatial_query_pos_mask': [pos_masks[idx2]], 'spatial_query_neg_mask': [neg_masks[idx2]]})
829
+
830
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
831
+ token_emb = gtext['token_emb']
832
+ tokens = gtext['tokens']
833
+
834
+ grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
835
+ non_zero_query_mask = torch.zeros(grd_emb[:,None].shape[:-1], dtype=torch.bool, device=grd_emb.device)
836
+ extra['grounding_tokens'] = grd_emb[:,None]
837
+ extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
838
+
839
+ assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
840
+ features = self.backbone(images.tensor)
841
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
842
+
843
+ pred_gmasks = outputs['pred_gmasks'][idx]
844
+ v_emb = outputs['pred_gtexts'][idx]
845
+ t_emb = gtext['class_emb']
846
+
847
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
848
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
849
+
850
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
851
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
852
+
853
+ matched_id = out_prob.max(0)[1]
854
+ grd_masks += [pred_gmasks[matched_id,:,:]]
855
+ # grd_masks += [outputs['prev_mask'][0]]
856
+
857
+ mask_pred_results += [torch.cat(grd_masks)]
858
+
859
+ # comment for multi object inference.
860
+ # mask_pred_results = []
861
+ # for idx, batch_per_image in enumerate(batched_inputs):
862
+ # grd_texts = batch_per_image['groundings']['texts']
863
+ # grd_texts = [x[0] for x in grd_texts]
864
+
865
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
866
+ # token_emb = gtext['token_emb']
867
+ # tokens = gtext['tokens']
868
+ # query_emb = token_emb[tokens['attention_mask'].bool()]
869
+ # non_zero_query_mask = torch.zeros(query_emb[:,None].shape[:-1], dtype=torch.bool, device=query_emb.device)
870
+
871
+ # extra['grounding_tokens'] = query_emb[:,None]
872
+ # extra['grounding_nonzero_mask'] = non_zero_query_mask.t()
873
+
874
+ # features = self.backbone(images.tensor)
875
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
876
+
877
+ # pred_gmasks = outputs['pred_gmasks'][idx]
878
+ # v_emb = outputs['pred_gtexts'][idx]
879
+ # t_emb = gtext['class_emb']
880
+
881
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
882
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
883
+
884
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
885
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
886
+
887
+ # matched_id = out_prob.max(0)[1]
888
+ # mask_pred_results += [pred_gmasks[matched_id,:,:]]
889
+
890
+ for i in range(len(mask_pred_results)):
891
+ # upsample masks
892
+ mask_pred_results[i] = F.interpolate(
893
+ mask_pred_results[i][None,],
894
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
895
+ mode="bilinear",
896
+ align_corners=False,
897
+ )[0]
898
+
899
+ processed_results = []
900
+ for mask_pred_result, input_per_image, image_size in zip(
901
+ mask_pred_results, batched_inputs, images.image_sizes
902
+ ):
903
+ height = input_per_image.get("height", image_size[0])
904
+ width = input_per_image.get("width", image_size[1])
905
+ processed_results.append({})
906
+
907
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
908
+ mask_pred_result, image_size, height, width
909
+ )
910
+ processed_results[-1]['grounding_mask'] = mask_pred_result
911
+
912
+ return processed_results
913
+
914
+ def prepare_targets(self, batched_inputs, images):
915
+ h_pad, w_pad = images.tensor.shape[-2:]
916
+ new_targets = []
917
+ for idx, batch_per_image in enumerate(batched_inputs):
918
+ target_dict = {}
919
+ if self.task_switch['mask']:
920
+ targets_per_image = batch_per_image['instances'].to(self.device)
921
+ # pad gt
922
+ gt_masks = targets_per_image.gt_masks.tensor
923
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
924
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
925
+
926
+ gt_boxes = targets_per_image.gt_boxes.tensor
927
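+ # normalize xyxy boxes by the padded image size, then convert to (cx, cy, w, h) in [0, 1]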
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
928
+ gt_boxes = gt_boxes / ratio
929
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
930
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
931
+
932
+ target_dict.update({
933
+ "labels": targets_per_image.gt_classes,
934
+ "is_things": targets_per_image.is_things,
935
+ "masks": padded_masks,
936
+ "boxes": gt_boxes,
937
+ })
938
+
939
+ if self.task_switch['spatial']:
940
+ # prepare targets for spatial query
941
+ target_dict['gt_spatial_masks'] = batch_per_image['spatial_query']['gt_masks']
942
+
943
+ if self.task_switch['grounding']:
944
+ grd_masks = batch_per_image['groundings']['masks']
945
+ grd_texts = batch_per_image['groundings']['texts']
946
+ grd_hash = batch_per_image['groundings']['hash']
947
+ grd_task = batch_per_image['groundings']['mode']
948
+
949
+ if len(grd_masks) == 0:
950
+ padded_masks = None
951
+ else:
952
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
953
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
954
+
955
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
956
+ token_emb = gtext['token_emb']
957
+ tokens = gtext['tokens']
958
+
959
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
960
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)
961
+ selected_mask[unique_hash_id] = True
962
+
963
+ selected_token_emb = token_emb[selected_mask]
964
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
965
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
966
+
967
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
968
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
969
+ class_emb = token_emb[class_idx]
970
+
971
+ target_dict['grounding_masks'] = padded_masks
972
+ target_dict['grounding_query_embs'] = query_emb
973
+ target_dict['grounding_class_embs'] = class_emb
974
+ target_dict['grounding_hash'] = grd_hash
975
+ target_dict['grounding_task'] = grd_task
976
+
977
+ new_targets.append(target_dict)
978
+ return new_targets
979
+
980
+ def prepare_next_spaital_mask(self, outputs, batched_inputs, mode='best'):
981
+ gt_masks = [batched_inputs[i]['spatial_query']['gt_masks'] for i in range(len(batched_inputs))]
982
+ gt_masks = Spatial_ImageList.from_tensors(gt_masks, self.size_divisibility).tensor
983
+
984
+ pred_masks = (F.interpolate(outputs['prev_mask'], size=gt_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5)
985
+ prev_masks = nn.utils.rnn.pad_sequence(outputs['spatial_query_pos_mask'], padding_value=False, batch_first=True) | \
986
+ nn.utils.rnn.pad_sequence(outputs['spatial_query_neg_mask'], padding_value=False, batch_first=True)
987
+
988
+ fn = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks) # fn: False Negative, gt:1, pred:0, prev:0
989
+ fp = (~gt_masks & pred_masks) & (~prev_masks) # fp: False Positive, gt:0, pred:1, prev:0
990
+
991
+ # compute iou between gt and pred
992
+ iou = (gt_masks & pred_masks).sum(list(range(2,len(fn.shape)))) / ((gt_masks | pred_masks).sum(dim=list(range(2,len(fn.shape)))) + 1e-8)
993
+ fn_sum = fn.sum(dim=list(range(2,len(fn.shape))))
994
+ fp_sum = fp.sum(dim=list(range(2,len(fp.shape))))
995
+
996
+ is_postive = fn_sum > fp_sum
997
+ select_mask = torch.zeros_like(fn)
998
+ select_mask[is_postive] = fn[is_postive]
999
+ select_mask[~is_postive] = fp[~is_postive]
1000
+ # is_postive = torch.ones(len(fn_sum), device=torch.cuda.current_device()).bool()
1001
+
1002
+ # conv implementation
1003
+ bs,ns,h,w = select_mask.shape
1004
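+ # distance from each error-region pixel to the region boundary (the padding makes image borders count as boundary); the per-mask argmax is the point deepest inside the error region and becomes the next simulated click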
+ mask_dt = (distance_transform((~F.pad(select_mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(bs*ns,-1)
1005
+ if mode == 'best':
1006
+ max_xy_idx = torch.stack([torch.arange(bs*ns), mask_dt.max(dim=-1)[1].cpu()]).tolist()
1007
+ elif mode == 'best_random':
1008
+ max_xy_idx = torch.stack([torch.arange(bs*ns), torch.cat([(mask_dt[i] > 0).nonzero()[torch.randint(0, len((mask_dt[i] > 0).nonzero()), (1,))][0] for i in range(len(mask_dt))]).cpu()]).tolist()
1009
+ next_mask = torch.zeros(gt_masks.shape, device=torch.cuda.current_device()).bool()
1010
+ next_mask = next_mask.view(bs*ns,-1)
1011
+ next_mask[max_xy_idx] = True
1012
+ next_mask = next_mask.reshape((bs*ns,1,h,w)).float()
1013
+ dilation = 3
1014
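+ # dilate the single-pixel click into a small square by convolving with the all-ones 3x3 kernel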
+ next_mask = F.conv2d(next_mask, self.dilation_kernel, padding=dilation//2).reshape(bs,ns,h,w) > 0
1015
+
1016
+ # zero out the next click for samples whose IoU already exceeds 0.925 (no further correction needed)
1017
+ keep = (iou < 0.925)
1018
+ next_mask = next_mask & keep.view(bs,ns,1,1)
1019
+
1020
+ pos_mask = []
1021
+ neg_mask = []
1022
+ for idx, ip in enumerate(is_postive):
1023
+ mask_len = len(outputs['spatial_query_pos_mask'][idx])
1024
+ pos_mask += [outputs['spatial_query_pos_mask'][idx] | (next_mask[idx][:mask_len] & ip[:mask_len,None,None])]
1025
+ neg_mask += [outputs['spatial_query_neg_mask'][idx] | (next_mask[idx][:mask_len] & (~ip[:mask_len,None,None]))]
1026
+
1027
+ if 'false_positive_mask' in outputs:
1028
+ fp = outputs['false_positive_mask'] | fp
1029
+ return {'spatial_query_pos_mask': pos_mask, 'spatial_query_neg_mask': neg_mask, 'false_positive_mask': fp}
1030
+
1031
+ def semantic_inference(self, mask_cls, mask_pred):
1032
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
1033
+ mask_pred = mask_pred.sigmoid()
1034
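+ # combine per-query class probabilities (Q, C) with per-query mask probabilities (Q, H, W) into per-class score maps (C, H, W)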
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
1035
+ return semseg
1036
+
1037
+ def panoptic_inference(self, mask_cls, mask_pred):
1038
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
1039
+ mask_pred = mask_pred.sigmoid()
1040
+
1041
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
1042
+ cur_scores = scores[keep]
1043
+ cur_classes = labels[keep]
1044
+ cur_masks = mask_pred[keep]
1045
+ cur_mask_cls = mask_cls[keep]
1046
+ cur_mask_cls = cur_mask_cls[:, :-1]
1047
+
1048
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
1049
+
1050
+ h, w = cur_masks.shape[-2:]
1051
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
1052
+ segments_info = []
1053
+
1054
+ current_segment_id = 0
1055
+
1056
+ if cur_masks.shape[0] == 0:
1057
+ # We didn't detect any mask :(
1058
+ return panoptic_seg, segments_info
1059
+ else:
1060
+ # take argmax
1061
+ cur_mask_ids = cur_prob_masks.argmax(0)
1062
+ stuff_memory_list = {}
1063
+ for k in range(cur_classes.shape[0]):
1064
+ pred_class = cur_classes[k].item()
1065
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
1066
+ mask_area = (cur_mask_ids == k).sum().item()
1067
+ original_area = (cur_masks[k] >= 0.5).sum().item()
1068
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
1069
+
1070
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
1071
+ if mask_area / original_area < self.overlap_threshold:
1072
+ continue
1073
+
1074
+ # merge stuff regions
1075
+ if not isthing:
1076
+ if int(pred_class) in stuff_memory_list.keys():
1077
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
1078
+ continue
1079
+ else:
1080
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
1081
+
1082
+ current_segment_id += 1
1083
+ panoptic_seg[mask] = current_segment_id
1084
+
1085
+ segments_info.append(
1086
+ {
1087
+ "id": current_segment_id,
1088
+ "isthing": bool(isthing),
1089
+ "category_id": int(pred_class),
1090
+ }
1091
+ )
1092
+
1093
+ return panoptic_seg, segments_info
1094
+
1095
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
1096
+ # mask_pred is already processed to have the same shape as original input
1097
+ image_size = mask_pred.shape[-2:]
1098
+
1099
+ # [Q, K]
1100
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
1101
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
1102
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
1103
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
1104
+
1105
+ labels_per_image = labels[topk_indices]
1106
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
1107
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
1108
+ mask_pred = mask_pred[topk_indices]
1109
+ if box_pred is not None:
1110
+ box_pred = box_pred[topk_indices]
1111
+
1112
+ # if this is panoptic segmentation, we only keep the "thing" classes
1113
+ if self.panoptic_on:
1114
+ keep = torch.zeros_like(scores_per_image).bool()
1115
+ for i, lab in enumerate(labels_per_image):
1116
+ keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
1117
+
1118
+ scores_per_image = scores_per_image[keep]
1119
+ labels_per_image = labels_per_image[keep]
1120
+ mask_pred = mask_pred[keep]
1121
+
1122
+ if box_pred is not None:
1123
+ box_pred = box_pred[keep]
1124
+
1125
+ result = Instances(image_size)
1126
+ # mask (before sigmoid)
1127
+ result.pred_masks = (mask_pred > 0).float()
1128
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
1129
+ # Uncomment the following to get boxes from masks (this is slow)
1130
+
1131
+ if box_pred is not None:
1132
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
1133
+ else:
1134
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
1135
+
1136
+ # calculate average mask prob
1137
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
1138
+ result.scores = scores_per_image * mask_scores_per_image
1139
+ result.pred_classes = labels_per_image
1140
+
1141
+ return result
1142
+
1143
+ def prepare_targets4query(self, targets, images, topk=5):
1144
+ h_pad, w_pad = images.tensor.shape[-2:]
1145
+ new_targets = []
1146
+ new_queries = []
1147
+ for targets_per_image in targets:
1148
+ # we randomly sample maximally topk concepts
1149
+ unique_target_classes = [k for k in set(targets_per_image.gt_classes.tolist())]
1150
+ selected_target_classes = random.sample(unique_target_classes, min(topk, len(unique_target_classes)))
1151
+ new_targets_per_image = []
1152
+ new_queries_per_image = []
1153
+ for clss in selected_target_classes:
1154
+ indices = (targets_per_image.gt_classes == clss).nonzero().view(-1)
1155
+ # pad gt
1156
+ gt_masks = targets_per_image.gt_masks[indices]
1157
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
1158
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
1159
+
1160
+ # convert class into concept name and then token seq
1161
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings([BIOMED_CLASSES[clss]], name='grounding')
1162
+ query = getattr(self.sem_seg_head.predictor.lang_encoder, 'grounding_text_embeddings')
1163
+
1164
+ new_targets.append(
1165
+ {
1166
+ "labels": targets_per_image.gt_classes[indices],
1167
+ "masks": padded_masks,
1168
+ }
1169
+ )
1170
+ new_queries_per_image.append(query)
1171
+ new_queries.append(new_queries_per_image)
1172
+
1173
+ return new_targets, new_queries
1174
+
1175
+
1176
+
1177
+ @register_model
1178
+ def get_seem_model(cfg, **kwargs):
1179
+ return GeneralizedSEEM(cfg)
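Note on the refinement logic near the top of this file: the next simulated click is taken at the pixel of the dominant error region (false negatives vs. false positives) that lies farthest from the region boundary, found with a padded distance transform, then dilated into a small square; samples whose IoU already exceeds 0.925 are dropped. A minimal standalone sketch of that sampling step, using scipy.ndimage as an assumed stand-in for the torch distance transform and F.conv2d dilation used above:

import numpy as np
from scipy.ndimage import distance_transform_edt, binary_dilation

def sample_next_click(error_mask: np.ndarray, dilation: int = 3) -> np.ndarray:
    # error_mask: (H, W) bool array of wrongly predicted pixels (FN or FP region).
    if not error_mask.any():
        return np.zeros_like(error_mask, dtype=bool)
    # Pad with background so border pixels get a finite distance, as in the padded
    # distance transform above, then measure distance to the nearest background pixel.
    padded = np.pad(error_mask, 1, constant_values=False)
    dist = distance_transform_edt(padded)[1:-1, 1:-1]
    click = np.zeros_like(error_mask, dtype=bool)
    click[np.unravel_index(dist.argmax(), dist.shape)] = True
    # Grow the single pixel into a small square, mimicking the 3x3 conv dilation.
    return binary_dilation(click, structure=np.ones((dilation, dilation), dtype=bool))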
modeling/architectures/xdecoder_model.py ADDED
@@ -0,0 +1,937 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected]), Ziyi Dou, Jianwei Yang
6
+ # --------------------------------------------------------
7
+
8
+ from typing import Tuple
9
+ import random
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ import numpy as np
15
+
16
+ from timm.models.layers import trunc_normal_
17
+ from nltk.stem.lancaster import LancasterStemmer
18
+ from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode
19
+ from detectron2.utils.memory import retry_if_cuda_oom
20
+ from detectron2.data import MetadataCatalog
21
+
22
+ from .build import register_model
23
+ from ..utils import configurable, get_class_names
24
+ from ..vision.backbone import build_backbone, Backbone
25
+ from ..body import build_xdecoder_head
26
+ from ..modules import sem_seg_postprocess, SetCriterion, HungarianMatcher, bbox_postprocess
27
+ from ..language import build_language_encoder
28
+ from ..language.loss import vl_similarity, image_text_contrastive_loss_queue
29
+ from utilities.prompt_engineering import prompt_engineering
30
+ from utilities.constants import COCO_PANOPTIC_CLASSES
31
+
32
+ st = LancasterStemmer()
33
+
34
+
35
+ class GeneralizedXdecoder(nn.Module):
36
+
37
+ @configurable
38
+ def __init__(
39
+ self,
40
+ *,
41
+ backbone: Backbone,
42
+ sem_seg_head: nn.Module,
43
+ criterion: nn.Module,
44
+ losses: dict,
45
+ num_queries: int,
46
+ object_mask_threshold: float,
47
+ overlap_threshold: float,
48
+ metadata,
49
+ task_switch: dict,
50
+ phrase_prob: float,
51
+ size_divisibility: int,
52
+ sem_seg_postprocess_before_inference: bool,
53
+ pixel_mean: Tuple[float],
54
+ pixel_std: Tuple[float],
55
+ # inference
56
+ semantic_on: bool,
57
+ panoptic_on: bool,
58
+ instance_on: bool,
59
+ test_topk_per_image: int,
60
+ train_dataset_name: str,
61
+ retrieval_emsemble: bool,
62
+ backbone_dim: int,
63
+ dim_proj: int,
64
+ ):
65
+ """
66
+ Args:
67
+ backbone: a backbone module, must follow detectron2's backbone interface
68
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
69
+ criterion: a module that defines the loss
70
+ num_queries: int, number of queries
71
+ object_mask_threshold: float, threshold to filter query based on classification score
72
+ for panoptic segmentation inference
73
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
74
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
75
+ segmentation inference
76
+ size_divisibility: Some backbones require the input height and width to be divisible by a
77
+ specific integer. We can use this to override such requirement.
78
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
79
+ to original input size before semantic segmentation inference or after.
80
+ For high-resolution dataset like Mapillary, resizing predictions before
81
+ inference will cause OOM error.
82
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
83
+ the per-channel mean and std to be used to normalize the input image
84
+ semantic_on: bool, whether to output semantic segmentation prediction
85
+ instance_on: bool, whether to output instance segmentation prediction
86
+ panoptic_on: bool, whether to output panoptic segmentation prediction
87
+ test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
88
+ """
89
+ super().__init__()
90
+ self.backbone = backbone
91
+ self.sem_seg_head = sem_seg_head
92
+ self.criterion = criterion
93
+ self.losses = losses
94
+ self.num_queries = num_queries
95
+ self.overlap_threshold = overlap_threshold
96
+ self.object_mask_threshold = object_mask_threshold
97
+ self.metadata = metadata
98
+ if size_divisibility < 0:
99
+ # use backbone size_divisibility if not set
100
+ size_divisibility = self.backbone.size_divisibility
101
+ self.size_divisibility = size_divisibility
102
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
103
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
104
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
105
+
106
+ # additional args
107
+ self.semantic_on = semantic_on
108
+ self.instance_on = instance_on
109
+ self.panoptic_on = panoptic_on
110
+
111
+ # caption argument
112
+ self.task_switch = task_switch
113
+ self.phrase_prob = phrase_prob
114
+
115
+ self.test_topk_per_image = test_topk_per_image
116
+ self.train_class_names = get_class_names(train_dataset_name)
117
+
118
+ self.retrieval_emsemble = retrieval_emsemble
119
+ # backbone itc loss
120
+ if task_switch['retrieval'] and retrieval_emsemble:
121
+ self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj))
122
+ trunc_normal_(self.backbone_proj, std=.02)
123
+
124
+ if not self.semantic_on:
125
+ assert self.sem_seg_postprocess_before_inference
126
+
127
+ @classmethod
128
+ def from_config(cls, cfg):
129
+ enc_cfg = cfg['MODEL']['ENCODER']
130
+ dec_cfg = cfg['MODEL']['DECODER']
131
+
132
+ # Loss parameters:
133
+ deep_supervision = dec_cfg['DEEP_SUPERVISION']
134
+ no_object_weight = dec_cfg['NO_OBJECT_WEIGHT']
135
+
136
+ # loss weights, switcher for task, and top layers to compute loss
137
+ loss_weights = {'mask': {'ce': dec_cfg['CLASS_WEIGHT'], 'dice': dec_cfg['DICE_WEIGHT'], 'bce': dec_cfg['MASK_WEIGHT']},
138
+ 'bbox': {'l1': dec_cfg['BBOX_WEIGHT'], 'giou': dec_cfg['GIOU_WEIGHT']},
139
+ 'caption': dec_cfg['CAPTION_WEIGHT'],
140
+ 'captioning': dec_cfg['CAPTIONING_WEIGHT'],
141
+ 'retrieval': {'decoder': dec_cfg['RETRIEVAL_WEIGHT'], 'backbone': dec_cfg['BACKBONER_WEIGHT']},
142
+ 'grounding': {'ce': dec_cfg['GCLASS_WEIGHT'], 'dice': dec_cfg['GDICE_WEIGHT'], 'bce': dec_cfg['GMASK_WEIGHT']}}
143
+
144
+ task_switch = {'bbox': dec_cfg.get('DETECTION', False),
145
+ 'mask': dec_cfg.get('MASK', True),
146
+ 'caption': dec_cfg['CAPTION'].get('ENABLED', False),
147
+ 'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
148
+ 'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
149
+ 'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}
150
+
151
+ top_x_layers = {'mask': dec_cfg.get('TOP_MASK_LAYERS', 10),
152
+ 'caption': dec_cfg.get('TOP_CAPTION_LAYERS', 10),
153
+ 'captioning': dec_cfg.get('TOP_CAPTIONING_LAYERS', 10),
154
+ 'retrieval': dec_cfg.get('TOP_RETRIEVAL_LAYERS', 10),
155
+ 'grounding': dec_cfg.get('TOP_GROUNDING_LAYERS', 10),}
156
+
157
+ # build model
158
+ extra = {'task_switch': task_switch}
159
+ backbone = build_backbone(cfg)
160
+ lang_encoder = build_language_encoder(cfg)
161
+ sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)
162
+
163
+ # building criterion
164
+ matcher = HungarianMatcher(
165
+ cost_class=loss_weights['mask']['ce'],
166
+ cost_mask=loss_weights['mask']['bce'],
167
+ cost_dice=loss_weights['mask']['dice'],
168
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
169
+ )
170
+
171
+ # init weight dict and criterion loss functions.
172
+ losses = {'seg': [], 'vlp': []}
173
+ if task_switch['mask']:
174
+ losses['seg'] += ["labels", "masks"]
175
+ if task_switch['caption']:
176
+ losses['seg'] += ["captions"]
177
+ if task_switch['grounding']:
178
+ losses['seg'] += ["groundings"]
179
+ if task_switch['captioning']:
180
+ losses['vlp'] += ["captionings"]
181
+ if task_switch['retrieval']:
182
+ losses['vlp'] += ["retrievals"]
183
+
184
+ weight_dict = {}
185
+ for key, turn_on in task_switch.items():
186
+ if turn_on:
187
+ if isinstance(loss_weights[key], dict):
188
+ # HACK it should support bbox in the future
189
+ for key_, weight in loss_weights[key].items():
190
+ weight_dict["loss_{}_{}_0".format(key, key_)] = weight # NOTE: hard code for segmentation that has multiple loss
191
+ else:
192
+ weight_dict["loss_{}_0".format(key)] = loss_weights[key]
193
+
194
+ # generate full weight dict and remove not computed layers.
195
+ if deep_supervision:
196
+ dec_layers = dec_cfg['DEC_LAYERS']
197
+ aux_weight_dict = {}
198
+ for i in range(dec_layers - 1):
199
+ for k, v in weight_dict.items():
200
+ if (i+1) > (top_x_layers[k.split('_')[1]] - 1):
201
+ continue
202
+ aux_weight_dict.update({k.replace('_0', f"_{i+1}"): v})
203
+ weight_dict.update(aux_weight_dict)
204
+
205
+ grd_weight = {'text': dec_cfg['GROUNDING']['TEXT_WEIGHT'], 'class': dec_cfg['GROUNDING']['CLASS_WEIGHT']}
206
+ # generate criterion for loss function.
207
+ criterion = SetCriterion(
208
+ sem_seg_head.num_classes,
209
+ matcher=matcher,
210
+ weight_dict=weight_dict,
211
+ top_x_layers=top_x_layers,
212
+ eos_coef=no_object_weight,
213
+ losses=[],
214
+ num_points=dec_cfg['TRAIN_NUM_POINTS'],
215
+ oversample_ratio=dec_cfg['OVERSAMPLE_RATIO'],
216
+ importance_sample_ratio=dec_cfg['IMPORTANCE_SAMPLE_RATIO'],
217
+ grounding_weight=grd_weight,
218
+ )
219
+
220
+ # extra logic
221
+ train_dataset_name = cfg['DATASETS']['TRAIN'][0] # HACK for only one training set.
222
+ phrase_prob = dec_cfg['CAPTION'].get('PHRASE_PROB', 0.5)
223
+
224
+ return {
225
+ "backbone": backbone,
226
+ "sem_seg_head": sem_seg_head,
227
+ "criterion": criterion,
228
+ "losses": losses,
229
+ "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
230
+ "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
231
+ "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
232
+ "metadata": MetadataCatalog.get(cfg['DATASETS']['TRAIN'][0]),
233
+ "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
234
+ "sem_seg_postprocess_before_inference": (
235
+ dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
236
+ or dec_cfg['TEST']['PANOPTIC_ON']
237
+ or dec_cfg['TEST']['INSTANCE_ON']
238
+ ),
239
+ "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
240
+ "pixel_std": cfg['INPUT']['PIXEL_STD'],
241
+ "task_switch": task_switch,
242
+ "phrase_prob": phrase_prob,
243
+ # inference
244
+ "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
245
+ "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
246
+ "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
247
+ "test_topk_per_image": cfg['COCO']['TEST']['DETECTIONS_PER_IMAGE'],
248
+ "train_dataset_name": train_dataset_name,
249
+ "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
250
+ "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
251
+ "dim_proj": cfg['MODEL']['DIM_PROJ'],
252
+ }
253
+
254
+ @property
255
+ def device(self):
256
+ return self.pixel_mean.device
257
+
258
+ def forward(self, batched_inputs, mode=None):
259
+ """
260
+ Args:
261
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
262
+ Each item in the list contains the inputs for one image.
263
+ For now, each item in the list is a dict that contains:
264
+ * "image": Tensor, image in (C, H, W) format.
265
+ * "instances": per-region ground truth
266
+ * Other information that's included in the original dicts, such as:
267
+ "height", "width" (int): the output resolution of the model (may be different
268
+ from input resolution), used in inference.
269
+ Returns:
270
+ list[dict]:
271
+ each dict has the results for one image. The dict contains the following keys:
272
+
273
+ * "sem_seg":
274
+ A Tensor that represents the
275
+ per-pixel segmentation predicted by the head.
276
+ The prediction has shape KxHxW that represents the logits of
277
+ each class for each pixel.
278
+ * "panoptic_seg":
279
+ A tuple that represents the panoptic output
280
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
281
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
282
+ Each dict contains keys "id", "category_id", "isthing".
283
+ """
284
+ if self.training:
285
+ losses = {}
286
+ if self.task_switch['mask']:
287
+ losses_seg = self.forward_seg(batched_inputs['coco'])
288
+ losses.update(losses_seg)
289
+ if self.task_switch['retrieval'] or self.task_switch['captioning']:
290
+ losses_vlp = self.forward_vlp(batched_inputs['vlp'])
291
+ losses.update(losses_vlp)
292
+ for k in list(losses.keys()):
293
+ if k in self.criterion.weight_dict:
294
+ losses[k] *= self.criterion.weight_dict[k]
295
+ else: # remove this loss if not specified in `weight_dict`
296
+ losses.pop(k)
297
+ return losses
298
+ else:
299
+ if mode == 'retrieval':
300
+ return self.evaluate_retrieval(batched_inputs)
301
+ elif mode == 'captioning':
302
+ return self.evaluate_captioning(batched_inputs)
303
+ elif mode == 'classification':
304
+ return self.evaluate_classification(batched_inputs)
305
+ elif mode == 'grounding_refcoco':
306
+ return self.evaluate_grounding(batched_inputs, mode)
307
+ else:
308
+ return self.evaluate(batched_inputs)
309
+
310
+
311
+ def forward_seg(self, batched_inputs):
312
+ images = [x["image"].to(self.device) for x in batched_inputs]
313
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
314
+ images = ImageList.from_tensors(images, self.size_divisibility)
315
+
316
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(self.train_class_names, is_eval=False)
317
+
318
+ extra = {}
319
+ # mask classification target
320
+ if "instances" in batched_inputs[0]:
321
+ # input bounding box is checked to be correct.
322
+ targets = self.prepare_targets(batched_inputs, images)
323
+
324
+ if self.task_switch['grounding']:
325
+ grounding_tokens = [x['grounding_query_embs'] for x in targets] # need to pad for more than one grounding token
326
+ grounding_tokens = nn.utils.rnn.pad_sequence(grounding_tokens)
327
+ extra['grounding_tokens'] = grounding_tokens
328
+
329
+ features = self.backbone(images.tensor)
330
+ outputs = self.sem_seg_head(features, extra=extra)
331
+
332
+ _outputs = {}
333
+ for key, value in outputs.items():
334
+ if key == 'pred_logits':
335
+ _outputs[key] = value[:,:self.num_queries-1]
336
+ elif key == 'pred_masks':
337
+ _outputs[key] = value[:,:self.num_queries-1]
338
+ if self.task_switch['grounding']:
339
+ _outputs['pred_gmasks'] = value[:,self.num_queries:2*self.num_queries-1]
340
+ elif key == 'pred_captions':
341
+ _outputs[key] = value[:,:self.num_queries-1]
342
+ if self.task_switch['grounding']:
343
+ _outputs['pred_gtexts'] = value[:,self.num_queries:2*self.num_queries-1]
344
+ elif key == 'aux_outputs':
345
+ _outputs[key] = []
346
+ for i in range(len(value)):
347
+ _outputs[key] += [{}]
348
+ for _key, _value in value[i].items():
349
+ if _key == 'pred_logits':
350
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
351
+ elif _key == 'pred_masks':
352
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
353
+ if self.task_switch['grounding']:
354
+ _outputs[key][i]['pred_gmasks'] = _value[:,self.num_queries:2*self.num_queries-1]
355
+ elif _key == 'pred_captions':
356
+ _outputs[key][i][_key] = _value[:,:self.num_queries-1]
357
+ if self.task_switch['grounding']:
358
+ _outputs[key][i]['pred_gtexts'] = _value[:,self.num_queries:2*self.num_queries-1]
359
+ outputs = _outputs
360
+
361
+ extra = {'lang_logit': self.sem_seg_head.predictor.lang_encoder.logit_scale,
362
+ 'class_embeddings': getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('default'))}
363
+
364
+ # bipartite matching-based loss
365
+ self.criterion.losses = self.losses['seg'] # seg criterion losses
366
+ losses = self.criterion(outputs, targets, extra)
367
+
368
+ del outputs
369
+ del _outputs
370
+ return losses
371
+
372
+ def forward_vlp(self, batched_inputs):
373
+ images = [x["image"].to(self.device) for x in batched_inputs]
374
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
375
+ images = ImageList.from_tensors(images, self.size_divisibility)
376
+ targets_vlp = self.prepare_vlp_targets(batched_inputs, images.tensor.device)
377
+
378
+ extra = {"token_embedding": self.sem_seg_head.predictor.lang_encoder.lang_encoder.token_embedding,
379
+ "lang_encoder": self.sem_seg_head.predictor.lang_encoder,
380
+ "training": self.training}
381
+
382
+ features = self.backbone(images.tensor)
383
+ outputs = self.sem_seg_head(features, target_queries=None, target_vlp=targets_vlp, task='vlp', extra=extra)
384
+
385
+ for key, value in outputs.items():
386
+ if key == 'pred_captionings':
387
+ outputs[key] = value
388
+ elif key == 'pred_captions':
389
+ # outputs[key] = value[:,-1:]
390
+ outputs[key] = value
391
+ elif key == 'aux_outputs':
392
+ outputs[key] = []
393
+ for i in range(len(value)):
394
+ outputs[key] += [{}]
395
+ for _key, _value in value[i].items():
396
+ if _key == 'pred_captions':
397
+ # outputs[key][i][_key] = _value[:,-1:]
398
+ outputs[key][i][_key] = _value
399
+ elif _key == 'pred_captionings':
400
+ outputs[key][i][_key] = _value
401
+
402
+ self.criterion.losses = self.losses['vlp'] # seg criterion losses
403
+ losses = self.criterion.forward_vlp(outputs, targets_vlp, extra)
404
+ del outputs
405
+
406
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
407
+ # compute backbone vlp.
408
+ v_emb = features['res5']
409
+ bs,nc,_,_ = v_emb.shape
410
+ v_emb = v_emb.reshape(bs,nc,-1)
411
+ v_emb = F.adaptive_avg_pool1d(v_emb, 1).reshape(bs,nc) @ self.backbone_proj
412
+ t_emb = torch.cat([x['caption_proj'] for x in targets_vlp], dim=0)
413
+ loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, self.sem_seg_head.predictor.lang_encoder, None)
414
+ losses['loss_retrieval_backbone_0'] = loss_contrast
415
+ return losses
416
+
417
+ def evaluate(self, batched_inputs):
418
+ images = [x["image"].to(self.device) for x in batched_inputs]
419
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
420
+
421
+ images = ImageList.from_tensors(images, self.size_divisibility)
422
+ img_bs = images.tensor.shape[0]
423
+
424
+ targets = targets_grounding = queries_grounding = None
425
+ features = self.backbone(images.tensor)
426
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
427
+
428
+ mask_cls_results = outputs["pred_logits"]
429
+ mask_pred_results = outputs["pred_masks"]
430
+ box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))]
431
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
432
+
433
+ # upsample masks
434
+ mask_pred_results = F.interpolate(
435
+ mask_pred_results,
436
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
437
+ mode="bicubic",
438
+ align_corners=False,
439
+ antialias=True
440
+ )
441
+
442
+ input_size = mask_pred_results.shape[-2:]
443
+ keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False
444
+ del outputs
445
+
446
+ processed_results = []
447
+ for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip(
448
+ mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes
449
+ ):
450
+ height = input_per_image.get("height", image_size[0])
451
+ width = input_per_image.get("width", image_size[1])
452
+ processed_results.append({})
453
+
454
+ if self.sem_seg_postprocess_before_inference:
455
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
456
+ mask_pred_result, image_size, height, width
457
+ )
458
+ mask_cls_result = mask_cls_result.to(mask_pred_result)
459
+
460
+ # semantic segmentation inference
461
+ if self.semantic_on:
462
+ r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd)
463
+ if not self.sem_seg_postprocess_before_inference:
464
+ r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
465
+ processed_results[-1]["sem_seg"] = r
466
+
467
+ # panoptic segmentation inference
468
+ if self.panoptic_on:
469
+ panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
470
+ processed_results[-1]["panoptic_seg"] = panoptic_r
471
+
472
+ # instance segmentation inference
473
+ if self.instance_on:
474
+ if self.task_switch['bbox']:
475
+ box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width)
476
+ instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result)
477
+ processed_results[-1]["instances"] = instance_r
478
+ if self.task_switch['caption']:
479
+ processed_results[-1]["captions"] = caption_pred_result
480
+ processed_results[-1]["masks"] = mask_pred_result
481
+
482
+ return processed_results
483
+
484
+ def evaluate_retrieval(self, batched_inputs):
485
+ images = [x["image"].to(self.device) for x in batched_inputs]
486
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
487
+ images = ImageList.from_tensors(images, self.size_divisibility)
488
+ img_bs = images.tensor.shape[0]
489
+
490
+ targets = targets_grounding = queries_grounding = None
491
+ features = self.backbone(images.tensor)
492
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
493
+ v_emb_it = outputs['pred_captions'][:,-1]
494
+
495
+ # compute backbone score
496
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
497
+ _v_emb_it = features['res5']
498
+ bs,nc,_,_ = _v_emb_it.shape
499
+ _v_emb_it = _v_emb_it.reshape(bs,nc,-1)
500
+ _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj
501
+
502
+ processed_results = []
503
+ for idx, batch_data in enumerate(batched_inputs):
504
+ caption_ids = []
505
+ t_emb_its = []
506
+ processed_results.append({})
507
+ for caption in batch_data['captions']:
508
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption)
509
+ t_emb_it = lang_results['class_emb']
510
+ caption_ids.append(batch_data['image_id'])
511
+ t_emb_its.append(t_emb_it)
512
+
513
+ t_emb_it = torch.cat(t_emb_its, dim=0)
514
+
515
+ image_embeds = [v_emb_it[idx].unsqueeze(0)]
516
+ if self.task_switch['retrieval'] and self.retrieval_emsemble:
517
+ image_embeds += [_v_emb_it[idx].unsqueeze(0)]
518
+ caption_results = {
519
+ 'image_embeds': image_embeds,
520
+ 'text_embeds': t_emb_it,
521
+ 'caption_ids': caption_ids,
522
+ 'image_ids': batch_data['image_id'],
523
+ }
524
+ processed_results[-1]["caption"] = caption_results
525
+
526
+ del features
527
+ return processed_results
528
+
529
+ def evaluate_captioning(self, batched_inputs):
530
+ images = [x["image"].to(self.device) for x in batched_inputs]
531
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
532
+ images = ImageList.from_tensors(images, self.size_divisibility)
533
+ img_bs = images.tensor.shape[0]
534
+
535
+ if not hasattr(self, 'start_token'):
536
+ self.start_token = torch.tensor([[49406]*77], device=self.device)
537
+
538
+ targets = targets_grounding = queries_grounding = None
539
+ features = self.backbone(images.tensor)
540
+
541
+ captioning_mask = None
542
+ if 'captioning_mask' in batched_inputs[-1]:
543
+ captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs])
544
+
545
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra={'start_token': self.start_token, 'captioning_mask': captioning_mask})
546
+
547
+ processed_results = []
548
+ for idx, batch_data in enumerate(batched_inputs):
549
+ processed_results.append({})
550
+ processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx]
551
+ processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0]
552
+ processed_results[-1]["image_id"] = batched_inputs[idx]['image_id']
553
+
554
+ return processed_results
555
+
556
+ def evaluate_classification(self, batched_inputs):
557
+ images = [x["image"].to(self.device) for x in batched_inputs]
558
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
559
+ images = ImageList.from_tensors(images, self.size_divisibility)
560
+ img_bs = images.tensor.shape[0]
561
+
562
+ targets = targets_grounding = queries_grounding = None
563
+ features = self.backbone(images.tensor)
564
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
565
+
566
+ processed_results = []
567
+ for idx, batch_data in enumerate(batched_inputs):
568
+ processed_results.append({})
569
+ processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1]
570
+ return processed_results
571
+
572
+ def evaluate_grounding_baseline(self, batched_inputs, mode):
573
+ images = [x["image"].to(self.device) for x in batched_inputs]
574
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
575
+ images = ImageList.from_tensors(images, self.size_divisibility)
576
+ img_bs = images.tensor.shape[0]
577
+
578
+ targets = targets_grounding = queries_grounding = None
579
+ features = self.backbone(images.tensor)
580
+ outputs = self.sem_seg_head(features, target_queries=queries_grounding)
581
+
582
+ mask_pred_results = outputs["pred_masks"]
583
+ caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))]
584
+
585
+ # upsample masks
586
+ mask_pred_results = F.interpolate(
587
+ mask_pred_results,
588
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
589
+ mode="bicubic",
590
+ align_corners=False,
591
+ antialias=True
592
+ )
593
+
594
+ processed_results = []
595
+ for mask_pred_result, caption_pred_result, input_per_image, image_size in zip(
596
+ mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes
597
+ ):
598
+ height = input_per_image.get("height", image_size[0])
599
+ width = input_per_image.get("width", image_size[1])
600
+ processed_results.append({})
601
+
602
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
603
+ mask_pred_result, image_size, height, width
604
+ )[:-1]
605
+
606
+ texts_all = input_per_image['groundings']['texts']
607
+ grd_masks = []
608
+ for texts in texts_all:
609
+ if mode == 'grounding_refcoco':
610
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True)
611
+ elif mode == 'grounding_phrasecut':
612
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False)
613
+ t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t()
614
+ v_emb = caption_pred_result[:-1]
615
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
616
+ vt_sim = v_emb @ t_emb
617
+ max_id = vt_sim.max(0)[1][0]
618
+ grd_masks += [mask_pred_result[max_id]]
619
+ processed_results[-1]['grounding_mask'] = torch.stack(grd_masks)
620
+
621
+ return processed_results
622
+
623
+ def evaluate_grounding(self, batched_inputs, mode):
624
+ images = [x["image"].to(self.device) for x in batched_inputs]
625
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
626
+ images = ImageList.from_tensors(images, self.size_divisibility)
627
+
628
+ extra = {}
629
+ # mask_pred_results = []
630
+ # for idx, batch_per_image in enumerate(batched_inputs):
631
+ # grd_texts = batch_per_image['groundings']['texts']
632
+ # grd_masks = []
633
+ # for anno_text in grd_texts:
634
+ # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False)
635
+ # token_emb = gtext['token_emb']
636
+ # tokens = gtext['tokens']
637
+
638
+ # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]]
639
+ # extra['grounding_tokens'] = grd_emb[:,None]
640
+
641
+ # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now"
642
+ # features = self.backbone(images.tensor)
643
+ # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
644
+
645
+ # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
646
+ # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
647
+ # t_emb = grd_emb[-1:]
648
+
649
+ # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
650
+ # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
651
+
652
+ # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
653
+ # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
654
+
655
+ # matched_id = out_prob.max(0)[1]
656
+ # grd_masks += [pred_gmasks[matched_id,:,:]]
657
+ # mask_pred_results += [torch.cat(grd_masks)]
658
+
659
+ # multi-object inference path (the single-phrase loop above is kept commented out).
660
+ mask_pred_results = []
661
+ for idx, batch_per_image in enumerate(batched_inputs):
662
+ grd_texts = batch_per_image['groundings']['texts']
663
+ grd_texts = [x[0] for x in grd_texts]
664
+
665
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
666
+ token_emb = gtext['token_emb']
667
+ tokens = gtext['tokens']
668
+ query_emb = token_emb[tokens['attention_mask'].bool()]
669
+ extra['grounding_tokens'] = query_emb[:,None]
670
+
671
+ features = self.backbone(images.tensor)
672
+ outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval')
673
+
674
+ pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1]
675
+ v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1]
676
+ t_emb = gtext['class_emb']
677
+
678
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
679
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
680
+
681
+ temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale
682
+ out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)
683
+
684
+ matched_id = out_prob.max(0)[1]
685
+ mask_pred_results += [pred_gmasks[matched_id,:,:]]
686
+
687
+ for i in range(len(mask_pred_results)):
688
+ # upsample masks
689
+ mask_pred_results[i] = F.interpolate(
690
+ mask_pred_results[i][None,],
691
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
692
+ mode="bicubic",
693
+ align_corners=False,
694
+ antialias=True
695
+ )[0]
696
+
697
+ processed_results = []
698
+ for mask_pred_result, input_per_image, image_size in zip(
699
+ mask_pred_results, batched_inputs, images.image_sizes
700
+ ):
701
+ height = input_per_image.get("height", image_size[0])
702
+ width = input_per_image.get("width", image_size[1])
703
+ processed_results.append({})
704
+
705
+ mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
706
+ mask_pred_result, image_size, height, width
707
+ )
708
+ processed_results[-1]['grounding_mask'] = mask_pred_result
709
+
710
+ # compute bbox
711
+ # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes()
712
+ # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
713
+ # processed_results[-1]['grounding_box'] = bbox
714
+
715
+ return processed_results
716
+
717
+ def prepare_vlp_targets(self, batched_inputs, device):
718
+ input_ids = []
719
+ attention_mask = []
720
+ for cnt, x in enumerate(batched_inputs):
721
+ captions = x['captions']
722
+ randid = random.randint(0, len(captions)-1)
723
+ input_ids += x['tokens']['input_ids'][randid:randid+1]
724
+ attention_mask += x['tokens']['attention_mask'][randid:randid+1]
725
+
726
+ input_ids = torch.stack(input_ids)
727
+ attention_mask = torch.stack(attention_mask)
728
+ tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
729
+ lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True)
730
+
731
+ target_vlp = []
732
+ for cnt, x in enumerate(batched_inputs):
733
+ target_dict = {}
734
+ target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1]
735
+ target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1]
736
+ target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1]
737
+ target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1]
738
+ target_vlp.append(target_dict)
739
+ return target_vlp
740
+
741
+ def prepare_targets(self, batched_inputs, images):
742
+ h_pad, w_pad = images.tensor.shape[-2:]
743
+ new_targets = []
744
+ for idx, batch_per_image in enumerate(batched_inputs):
745
+ targets_per_image = batch_per_image["instances"].to(self.device)
746
+
747
+ # pad gt
748
+ gt_masks = targets_per_image.gt_masks
749
+ padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
750
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
751
+
752
+ gt_boxes = targets_per_image.gt_boxes.tensor
753
+ ratio = torch.tensor([w_pad,h_pad,w_pad,h_pad]).to(gt_boxes.device)[None,:]
754
+ gt_boxes = gt_boxes / ratio
755
+ xc,yc,w,h = (gt_boxes[:,0] + gt_boxes[:,2])/2, (gt_boxes[:,1] + gt_boxes[:,3])/2, gt_boxes[:,2] - gt_boxes[:,0], gt_boxes[:,3] - gt_boxes[:,1]
756
+ gt_boxes = torch.stack([xc,yc,w,h]).permute(1,0)
757
+
758
+ target_dict = {
759
+ "labels": targets_per_image.gt_classes,
760
+ "is_things": targets_per_image.is_things,
761
+ "masks": padded_masks,
762
+ "boxes": gt_boxes
763
+ }
764
+
765
+ if self.task_switch['caption']:
766
+ caption = batch_per_image["captions"]
767
+ caption_noun = batch_per_image["captions_noun"]
768
+ rand_index = random.randint(0, len(caption)-1)
769
+
770
+ text = caption[rand_index]
771
+ nouns = caption_noun[rand_index]
772
+ noun_captions = [prompt_engineering(noun, topk=10000, suffix='.') for noun in nouns] + [text]
773
+
774
+ self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(noun_captions, is_eval=False, name='caption_noun', prompt=False)
775
+ ctext = getattr(self.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption_noun'))
776
+ target_dict["captions"] = ctext
777
+
778
+ target_dict["captions_hash"] = [(hash(st.stem(txt)) % 10**16) for txt in (nouns + [text])]
779
+ target_dict["labels_hash"] = [(hash(st.stem(COCO_PANOPTIC_CLASSES[label_id].replace('-other','').replace('-merged','').replace('-stuff',''))) % 10**16) for label_id in target_dict['labels']]
780
+
781
+ if self.task_switch['grounding']:
782
+ grd_masks = batch_per_image['groundings']['masks']
783
+ grd_texts = batch_per_image['groundings']['texts']
784
+ grd_hash = batch_per_image['groundings']['hash']
785
+ grd_task = batch_per_image['groundings']['mode']
786
+
787
+ if len(grd_masks) == 0:
788
+ padded_masks = None
789
+ else:
790
+ padded_masks = torch.zeros((grd_masks.shape[0], h_pad, w_pad), dtype=grd_masks.dtype, device=grd_masks.device)
791
+ padded_masks[:, : grd_masks.shape[1], : grd_masks.shape[2]] = grd_masks
792
+
793
+ gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False)
794
+ token_emb = gtext['token_emb']
795
+ tokens = gtext['tokens']
796
+
797
+ unique_hash_id = np.unique(grd_hash, return_index=True)[1]
798
+ selected_mask = np.zeros(len(grd_hash)).astype(bool)  # np.bool alias removed in modern NumPy
799
+ selected_mask[unique_hash_id] = True
800
+
801
+ selected_token_emb = token_emb[selected_mask]
802
+ selected_attn_mask = tokens['attention_mask'][selected_mask]
803
+ query_emb = selected_token_emb[selected_attn_mask.bool()]
804
+
805
+ class_idx = tokens['attention_mask'].sum(dim=-1) - 1
806
+ class_idx = torch.stack((torch.arange(len(class_idx), device=class_idx.device), class_idx)).tolist()
807
+ class_emb = token_emb[class_idx]
808
+
809
+ target_dict['grounding_masks'] = padded_masks
810
+ target_dict['grounding_query_embs'] = query_emb
811
+ target_dict['grounding_class_embs'] = class_emb
812
+ target_dict['grounding_hash'] = grd_hash
813
+ target_dict['grounding_task'] = grd_task
814
+
815
+ new_targets.append(target_dict)
816
+ return new_targets
817
+
818
+ def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False):
819
+ if keep_sem_bgd:
820
+ mask_cls = F.softmax(mask_cls, dim=-1)
821
+ else:
822
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
823
+ mask_pred = mask_pred.sigmoid()
824
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
825
+ return semseg
826
+
827
+ def panoptic_inference(self, mask_cls, mask_pred):
828
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
829
+ mask_pred = mask_pred.sigmoid()
830
+
831
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
832
+ cur_scores = scores[keep]
833
+ cur_classes = labels[keep]
834
+ cur_masks = mask_pred[keep]
835
+ cur_mask_cls = mask_cls[keep]
836
+ cur_mask_cls = cur_mask_cls[:, :-1]
837
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
838
+
839
+ h, w = cur_masks.shape[-2:]
840
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
841
+ segments_info = []
842
+
843
+ current_segment_id = 0
844
+
845
+ if cur_masks.shape[0] == 0:
846
+ # We didn't detect any mask :(
847
+ return panoptic_seg, segments_info
848
+ else:
849
+ # take argmax
850
+ cur_mask_ids = cur_prob_masks.argmax(0)
851
+ stuff_memory_list = {}
852
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
853
+ for k in range(cur_classes.shape[0]):
854
+ pred_class = cur_classes[k].item()
855
+ isthing = pred_class in thing_dataset_id_to_contiguous_id.values()
856
+ mask_area = (cur_mask_ids == k).sum().item()
857
+ original_area = (cur_masks[k] >= 0.5).sum().item()
858
+ mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
859
+
860
+ if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
861
+ if mask_area / original_area < self.overlap_threshold:
862
+ continue
863
+
864
+ # merge stuff regions
865
+ if not isthing:
866
+ if int(pred_class) in stuff_memory_list.keys():
867
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
868
+ continue
869
+ else:
870
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
871
+
872
+ current_segment_id += 1
873
+ panoptic_seg[mask] = current_segment_id
874
+
875
+ segments_info.append(
876
+ {
877
+ "id": current_segment_id,
878
+ "isthing": bool(isthing),
879
+ "category_id": int(pred_class),
880
+ }
881
+ )
882
+ return panoptic_seg, segments_info
883
+
884
+ def instance_inference(self, mask_cls, mask_pred, box_pred):
885
+ # mask_pred is already processed to have the same shape as original input
886
+ image_size = mask_pred.shape[-2:]
887
+
888
+ # [Q, K]
889
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
890
+ labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
891
+ # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
892
+ scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
893
+
894
+ labels_per_image = labels[topk_indices]
895
+ topk_indices = (topk_indices // self.sem_seg_head.num_classes)
896
+ # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
897
+ mask_pred = mask_pred[topk_indices]
898
+ if box_pred is not None:
899
+ box_pred = box_pred[topk_indices]
900
+
901
+ # if this is panoptic segmentation, we only keep the "thing" classes
902
+ if self.panoptic_on:
903
+ thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {}
904
+ keep = torch.zeros_like(scores_per_image).bool()
905
+ for i, lab in enumerate(labels_per_image):
906
+ keep[i] = lab in thing_dataset_id_to_contiguous_id.values()
907
+
908
+ scores_per_image = scores_per_image[keep]
909
+ labels_per_image = labels_per_image[keep]
910
+ mask_pred = mask_pred[keep]
911
+
912
+ if box_pred is not None:
913
+ box_pred = box_pred[keep]
914
+
915
+ result = Instances(image_size)
916
+ # mask (before sigmoid)
917
+ result.pred_masks = (mask_pred > 0).float()
918
+ # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
919
+ # Uncomment the following to get boxes from masks (this is slow)
920
+
921
+ if box_pred is not None:
922
+ result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
923
+ else:
924
+ result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
925
+
926
+ # calculate average mask prob
927
+ mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
928
+ result.scores = scores_per_image * mask_scores_per_image
929
+ result.pred_classes = labels_per_image
930
+
931
+ return result
932
+
933
+
934
+
935
+ @register_model
936
+ def get_xdecoder_model(cfg, **kwargs):
937
+ return GeneralizedXdecoder(cfg)
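Note on evaluate_grounding above: each referring phrase is matched to one grounding query by comparing L2-normalised query caption embeddings with the phrase embedding (vl_similarity) and taking the argmax over queries, whose mask is then returned. A self-contained sketch of that matching step, assuming vl_similarity amounts to a temperature-scaled dot product (the temperature value here is illustrative):

import torch

def match_phrases_to_masks(v_emb, t_emb, pred_gmasks, temperature=100.0):
    # v_emb: (Q, D) caption embeddings of the grounding queries
    # t_emb: (T, D) one embedding per referring phrase
    # pred_gmasks: (Q, H, W) grounding mask logits
    v = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
    t = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
    logits = temperature * v @ t.t()        # (Q, T) scaled cosine similarities
    matched = logits.max(dim=0).indices     # best query index for each phrase
    return pred_gmasks[matched]             # (T, H, W), one mask per phrase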
modeling/body/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .xdecoder_head import *
2
+ from .build import *
3
+
4
+ def build_xdecoder_head(config, *args, **kwargs):
5
+ model_name = config['MODEL']['HEAD']
6
+ if not is_model(model_name):
7
+ raise ValueError(f'Unknown model: {model_name}')
8
+
9
+ body = model_entrypoints(model_name)(config, *args, **kwargs)
10
+ return body
modeling/body/build.py ADDED
@@ -0,0 +1,13 @@
1
+ _model_entrypoints = {}
2
+
3
+ def register_body(fn):
4
+ module_name_split = fn.__module__.split('.')
5
+ model_name = module_name_split[-1]
6
+ _model_entrypoints[model_name] = fn
7
+ return fn
8
+
9
+ def model_entrypoints(model_name):
10
+ return _model_entrypoints[model_name]
11
+
12
+ def is_model(model_name):
13
+ return model_name in _model_entrypoints
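Note on the registry above: entry points are keyed by the defining module's file name (the last component of fn.__module__), so a head becomes discoverable simply by being imported from modeling/body/__init__.py. A usage sketch; the module, config values, and MyHead class are purely illustrative:

# hypothetical module modeling/body/my_head.py
from .build import register_body

@register_body                      # stored under the key "my_head"
def get_my_head(cfg, input_shape, lang_encoder, extra):
    return MyHead(cfg, input_shape, lang_encoder, extra)   # hypothetical head class

# consumer side, mirroring build_xdecoder_head:
#   cfg = {'MODEL': {'HEAD': 'my_head'}}
#   head = model_entrypoints(cfg['MODEL']['HEAD'])(cfg, input_shape, lang_encoder, extra)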
modeling/body/xdecoder_head.py ADDED
@@ -0,0 +1,126 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+ # Copyright (c) Facebook, Inc. and its affiliates.
8
+ from typing import Dict
9
+
10
+ from torch import nn
11
+
12
+ from detectron2.layers import ShapeSpec
13
+
14
+ from .build import register_body
15
+ from ..vision.encoder import build_encoder
16
+ from ..interface import build_decoder
17
+ from ..utils import configurable
18
+
19
+
20
+ class XdecoderHead(nn.Module):
21
+
22
+ @configurable
23
+ def __init__(
24
+ self,
25
+ input_shape: Dict[str, ShapeSpec],
26
+ *,
27
+ num_classes: int,
28
+ pixel_decoder: nn.Module,
29
+ loss_weight: float = 1.0,
30
+ ignore_value: int = -1,
31
+ # extra parameters
32
+ transformer_predictor: nn.Module,
33
+ transformer_in_feature: str,
34
+ binary_classes: bool,
35
+ ):
36
+ """
37
+ NOTE: this interface is experimental.
38
+ Args:
39
+ input_shape: shapes (channels and stride) of the input features
40
+ num_classes: number of classes to predict
41
+ pixel_decoder: the pixel decoder module
42
+ loss_weight: loss weight
43
+ ignore_value: category id to be ignored during training.
44
+ transformer_predictor: the transformer decoder that makes prediction
45
+ transformer_in_feature: input feature name to the transformer_predictor
46
+ """
47
+ super().__init__()
48
+
49
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
50
+ self.in_features = [k for k, v in input_shape]
51
+ feature_strides = [v.stride for k, v in input_shape]
52
+ feature_channels = [v.channels for k, v in input_shape]
53
+
54
+ self.ignore_value = ignore_value
55
+ self.common_stride = 4
56
+ self.loss_weight = loss_weight
57
+
58
+ self.pixel_decoder = pixel_decoder
59
+ self.predictor = transformer_predictor
60
+ self.transformer_in_feature = transformer_in_feature
61
+
62
+ self.num_classes = num_classes
63
+
64
+ if binary_classes:
65
+ self.num_classes = 1
66
+
67
+ @classmethod
68
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict):
69
+
70
+ in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE']
71
+ enc_cfg = cfg['MODEL']['ENCODER']
72
+ dec_cfg = cfg['MODEL']['DECODER']
73
+
74
+ # figure out in_channels to transformer predictor
75
+ if in_features_type == "transformer_encoder":
76
+ transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
77
+ elif in_features_type == "pixel_embedding":
78
+ transformer_predictor_in_channels = enc_cfg['MASK_DIM']
79
+ elif in_features_type == "multi_scale_pixel_decoder":
80
+ transformer_predictor_in_channels = enc_cfg['CONVS_DIM']
81
+ else:
82
+ transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels
83
+
84
+ return {
85
+ "input_shape": {
86
+ k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES']
87
+ },
88
+ "ignore_value": enc_cfg['IGNORE_VALUE'],
89
+ "num_classes": enc_cfg.get('NUM_CLASSES', None),
90
+ "pixel_decoder": build_encoder(cfg, input_shape),
91
+ "loss_weight": enc_cfg['LOSS_WEIGHT'],
92
+ "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'],
93
+ "transformer_predictor": build_decoder(
94
+ cfg,
95
+ transformer_predictor_in_channels,
96
+ lang_encoder,
97
+ mask_classification=True,
98
+ extra=extra,
99
+ ),
100
+ "binary_classes": enc_cfg['BINARY_CLASSES']
101
+ }
102
+
103
+ def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
104
+ return self.layers(features, mask, target_queries, target_vlp, task, extra)
105
+
106
+ def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
107
+ mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
108
+
109
+ if self.transformer_in_feature == "multi_scale_pixel_decoder":
110
+ predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra)
111
+ else:
112
+ if self.transformer_in_feature == "transformer_encoder":
113
+ assert (
114
+ transformer_encoder_features is not None
115
+ ), "Please use the TransformerEncoderPixelDecoder."
116
+ predictions = self.predictor(transformer_encoder_features, mask_features, mask)
117
+ elif self.transformer_in_feature == "pixel_embedding":
118
+ predictions = self.predictor(mask_features, mask_features, mask)
119
+ else:
120
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
121
+ return predictions
122
+
123
+
124
+ @register_body
125
+ def get_xdecoder_head(cfg, input_shape, lang_encoder, extra):
126
+ return XdecoderHead(cfg, input_shape, lang_encoder, extra)
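Note on from_config above: the transformer predictor's input width is chosen from the encoder config according to TRANSFORMER_IN_FEATURE. Condensed, the selection reads roughly as the sketch below (config keys as in the code; the helper name is ours):

def predictor_in_channels(cfg, input_shape):
    enc_cfg = cfg['MODEL']['ENCODER']
    dec_cfg = cfg['MODEL']['DECODER']
    feature = dec_cfg['TRANSFORMER_IN_FEATURE']
    if feature in ('transformer_encoder', 'multi_scale_pixel_decoder'):
        return enc_cfg['CONVS_DIM']          # predictor reads pixel-decoder conv features
    if feature == 'pixel_embedding':
        return enc_cfg['MASK_DIM']           # predictor reads per-pixel mask embeddings
    return input_shape[feature].channels     # otherwise a raw backbone feature map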
modeling/interface/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .xdecoder import *
2
+ from .seem_v0 import *
3
+ from .seem_v1 import *
4
+ from .seem_demo import *
5
+ from .build import *
6
+
7
+ def build_decoder(config, *args, **kwargs):
8
+ model_name = config['MODEL']['DECODER']['NAME']
9
+
10
+ if not is_model(model_name):
11
+ raise ValueError(f'Unknown model: {model_name}')
12
+
13
+ return model_entrypoints(model_name)(config, *args, **kwargs)
modeling/interface/build.py ADDED
@@ -0,0 +1,14 @@
1
+ _model_entrypoints = {}
2
+
3
+
4
+ def register_decoder(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+ _model_entrypoints[model_name] = fn
8
+ return fn
9
+
10
+ def model_entrypoints(model_name):
11
+ return _model_entrypoints[model_name]
12
+
13
+ def is_model(model_name):
14
+ return model_name in _model_entrypoints
modeling/interface/modules.py ADDED
@@ -0,0 +1,200 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import functional as F
6
+
7
+ from timm.models.layers import trunc_normal_
8
+ from detectron2.layers import Conv2d
9
+ import fvcore.nn.weight_init as weight_init
10
+
11
+ from ..utils import MultiheadAttention
12
+
13
+
14
+ class SelfAttentionLayer(nn.Module):
15
+
16
+ def __init__(self, d_model, nhead, dropout=0.0,
17
+ activation="relu", normalize_before=False):
18
+ super().__init__()
19
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
20
+
21
+ self.norm = nn.LayerNorm(d_model)
22
+ self.dropout = nn.Dropout(dropout)
23
+
24
+ self.activation = _get_activation_fn(activation)
25
+ self.normalize_before = normalize_before
26
+
27
+ self._reset_parameters()
28
+
29
+ def _reset_parameters(self):
30
+ for p in self.parameters():
31
+ if p.dim() > 1:
32
+ nn.init.xavier_uniform_(p)
33
+
34
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
35
+ return tensor if pos is None else tensor + pos
36
+
37
+ def forward_post(self, tgt,
38
+ tgt_mask: Optional[Tensor] = None,
39
+ tgt_key_padding_mask: Optional[Tensor] = None,
40
+ query_pos: Optional[Tensor] = None):
41
+
42
+ q = k = self.with_pos_embed(tgt, query_pos)
43
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
44
+ key_padding_mask=tgt_key_padding_mask)[0]
45
+ tgt = tgt + self.dropout(tgt2)
46
+ tgt = self.norm(tgt)
47
+ return tgt
48
+
49
+ def forward_pre(self, tgt,
50
+ tgt_mask: Optional[Tensor] = None,
51
+ tgt_key_padding_mask: Optional[Tensor] = None,
52
+ query_pos: Optional[Tensor] = None):
53
+ tgt2 = self.norm(tgt)
54
+ q = k = self.with_pos_embed(tgt2, query_pos)
55
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
56
+ key_padding_mask=tgt_key_padding_mask)[0]
57
+ tgt = tgt + self.dropout(tgt2)
58
+
59
+ return tgt
60
+
61
+ def forward(self, tgt,
62
+ tgt_mask: Optional[Tensor] = None,
63
+ tgt_key_padding_mask: Optional[Tensor] = None,
64
+ query_pos: Optional[Tensor] = None):
65
+ if self.normalize_before:
66
+ return self.forward_pre(tgt, tgt_mask,
67
+ tgt_key_padding_mask, query_pos)
68
+ return self.forward_post(tgt, tgt_mask,
69
+ tgt_key_padding_mask, query_pos)
70
+
71
+
72
+ class CrossAttentionLayer(nn.Module):
73
+
74
+ def __init__(self, d_model, nhead, dropout=0.0,
75
+ activation="relu", normalize_before=False):
76
+ super().__init__()
77
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
78
+
79
+ self.norm = nn.LayerNorm(d_model)
80
+ self.dropout = nn.Dropout(dropout)
81
+
82
+ self.activation = _get_activation_fn(activation)
83
+ self.normalize_before = normalize_before
84
+
85
+ self._reset_parameters()
86
+
87
+ def _reset_parameters(self):
88
+ for p in self.parameters():
89
+ if p.dim() > 1:
90
+ nn.init.xavier_uniform_(p)
91
+
92
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
93
+ return tensor if pos is None else tensor + pos
94
+
95
+ def forward_post(self, tgt, memory,
96
+ memory_mask: Optional[Tensor] = None,
97
+ memory_key_padding_mask: Optional[Tensor] = None,
98
+ pos: Optional[Tensor] = None,
99
+ query_pos: Optional[Tensor] = None):
100
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
101
+ key=self.with_pos_embed(memory, pos),
102
+ value=memory, attn_mask=memory_mask,
103
+ key_padding_mask=memory_key_padding_mask)
104
+ tgt = tgt + self.dropout(tgt2)
105
+ tgt = self.norm(tgt)
106
+ return tgt, avg_attn
107
+
108
+ def forward_pre(self, tgt, memory,
109
+ memory_mask: Optional[Tensor] = None,
110
+ memory_key_padding_mask: Optional[Tensor] = None,
111
+ pos: Optional[Tensor] = None,
112
+ query_pos: Optional[Tensor] = None):
113
+ tgt2 = self.norm(tgt)
114
+ tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
115
+ key=self.with_pos_embed(memory, pos),
116
+ value=memory, attn_mask=memory_mask,
117
+ key_padding_mask=memory_key_padding_mask)
118
+ tgt = tgt + self.dropout(tgt2)
119
+
120
+ return tgt, avg_attn
121
+
122
+ def forward(self, tgt, memory,
123
+ memory_mask: Optional[Tensor] = None,
124
+ memory_key_padding_mask: Optional[Tensor] = None,
125
+ pos: Optional[Tensor] = None,
126
+ query_pos: Optional[Tensor] = None):
127
+ if self.normalize_before:
128
+ return self.forward_pre(tgt, memory, memory_mask,
129
+ memory_key_padding_mask, pos, query_pos)
130
+ return self.forward_post(tgt, memory, memory_mask,
131
+ memory_key_padding_mask, pos, query_pos)
132
+
133
+
134
+ class FFNLayer(nn.Module):
135
+
136
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
137
+ activation="relu", normalize_before=False):
138
+ super().__init__()
139
+ # Implementation of Feedforward model
140
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
141
+ self.dropout = nn.Dropout(dropout)
142
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
143
+
144
+ self.norm = nn.LayerNorm(d_model)
145
+
146
+ self.activation = _get_activation_fn(activation)
147
+ self.normalize_before = normalize_before
148
+
149
+ self._reset_parameters()
150
+
151
+ def _reset_parameters(self):
152
+ for p in self.parameters():
153
+ if p.dim() > 1:
154
+ nn.init.xavier_uniform_(p)
155
+
156
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
157
+ return tensor if pos is None else tensor + pos
158
+
159
+ def forward_post(self, tgt):
160
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
161
+ tgt = tgt + self.dropout(tgt2)
162
+ tgt = self.norm(tgt)
163
+ return tgt
164
+
165
+ def forward_pre(self, tgt):
166
+ tgt2 = self.norm(tgt)
167
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
168
+ tgt = tgt + self.dropout(tgt2)
169
+ return tgt
170
+
171
+ def forward(self, tgt):
172
+ if self.normalize_before:
173
+ return self.forward_pre(tgt)
174
+ return self.forward_post(tgt)
175
+
176
+
177
+ def _get_activation_fn(activation):
178
+ """Return an activation function given a string"""
179
+ if activation == "relu":
180
+ return F.relu
181
+ if activation == "gelu":
182
+ return F.gelu
183
+ if activation == "glu":
184
+ return F.glu
185
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
186
+
187
+
188
+ class MLP(nn.Module):
189
+ """ Very simple multi-layer perceptron (also called FFN)"""
190
+
191
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
192
+ super().__init__()
193
+ self.num_layers = num_layers
194
+ h = [hidden_dim] * (num_layers - 1)
195
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
196
+
197
+ def forward(self, x):
198
+ for i, layer in enumerate(self.layers):
199
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
200
+ return x
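
Note: `modules.py` supplies the transformer building blocks used by the decoders that follow: `SelfAttentionLayer`, `CrossAttentionLayer` (which also returns head-averaged attention weights), `FFNLayer`, and a small `MLP` head, each supporting pre-norm or post-norm residual ordering. Below is a standalone sketch of one decoder round (cross-attention, self-attention, FFN) in the `(sequence, batch, channel)` layout these layers expect; it substitutes `torch.nn.MultiheadAttention` for the repo's custom `MultiheadAttention` import, so only the shapes and residual ordering are representative:

```python
# Sketch of one decoder round; shapes and ordering mirror forward_post above.
import torch
from torch import nn

d_model, nheads, bs = 256, 8, 2
queries = torch.randn(101, bs, d_model)      # object queries (Q, B, C)
memory = torch.randn(60 * 60, bs, d_model)   # flattened image features (HW, B, C)

cross_attn = nn.MultiheadAttention(d_model, nheads)
self_attn = nn.MultiheadAttention(d_model, nheads)
ffn = nn.Sequential(nn.Linear(d_model, 2048), nn.ReLU(), nn.Linear(2048, d_model))
norm1, norm2, norm3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)

# post-norm residual ordering, as in the forward_post methods
x, avg_attn = cross_attn(queries, memory, memory)   # attend queries to image features
queries = norm1(queries + x)
x, _ = self_attn(queries, queries, queries)         # queries attend to each other
queries = norm2(queries + x)
queries = norm3(queries + ffn(queries))             # position-wise feed-forward
print(queries.shape, avg_attn.shape)                # (101, 2, 256) and (2, 101, 3600)
```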
modeling/interface/prototype/__init__.py ADDED
File without changes
modeling/interface/prototype/attention_data_struct_seemdemo.py ADDED
@@ -0,0 +1,265 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ predict_name_matcher = {"predictions_class": ["pred_logits"],
13
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
14
+ "predictions_caption":["pred_captions", "pred_gtexts"],
15
+ "predictions_maskemb":["pred_maskembs", "pred_smaskembs"],
16
+ "predictions_pos_spatial":["pred_pspatials"],
17
+ "predictions_neg_spatial":["pred_nspatials"],
18
+ "predictions_pos_visual":["pred_pvisuals"],
19
+ "predictions_neg_visual":["pred_nvisuals"]}
20
+
21
+ predict_index_matcher = {"predictions_class": ["queries_object"],
22
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
23
+ "predictions_caption": ["queries_object", "queries_grounding"],
24
+ "predictions_maskemb":["queries_object", "queries_spatial"],
25
+ "predictions_pos_spatial":["all"],
26
+ "predictions_neg_spatial":["all"],
27
+ "predictions_pos_visual":["all"],
28
+ "predictions_neg_visual":["all"]}
29
+
30
+ class Variable(object):
31
+ '''
32
+ Store dataset variable for attention
33
+ output: embedding that accumulates during cross/self attention
34
+ pos: positional embedding that is fixed during cross/self attention
35
+ name: name of the variable
36
+ type: type of the variable, e.g. queries, tokens
37
+ attn_mask: attention mask for cross attention
38
+ masking: masking for padding
39
+ '''
40
+ def __init__(self, output, name, _type, pos=None):
41
+ self.output = output
42
+ self.pos = pos
43
+ self.name = name
44
+ self.type = _type
45
+ self.attn_mask = None
46
+ self.masking = None
47
+
48
+ def copy(self,):
49
+ output = self.output.clone() if self.output is not None else None
50
+ pos = self.pos.clone() if self.pos is not None else None
51
+ return Variable(output, self.name, self.type, pos)
52
+
53
+ class AttentionDataStruct(nn.Module):
54
+ '''
55
+ Store dataset structure for cross/self attention
56
+ task_switch: switch for different tasks
57
+
58
+ p_attn_variables: prototype of variables that is used in cross/self attention
59
+ p_self_attn: prototype of variables that is used in self attention
60
+ p_cross_attn: prototype of variables that is used in cross attention
61
+ p_iter: prototype of iteration for different queries
62
+ p_masking: prototype of masking for different tokens
63
+ p_duplication: prototype of duplication for different queries
64
+ '''
65
+ def __init__(self, attn_arch, task_switch):
66
+ super(AttentionDataStruct, self).__init__()
67
+ self.task_switch = task_switch
68
+
69
+ # p stands for prototype
70
+ self.p_attn_variables = attn_arch['VARIABLE']
71
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
72
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
73
+ self.p_masking = attn_arch['MASKING']
74
+ self.p_duplication = attn_arch['DUPLICATION']
75
+
76
+ self.num_layers = attn_arch['NUM_LAYERS']
77
+
78
+ def reset(self, flags, task, extra):
79
+ # reset variables
80
+ self.attn_variables = {}
81
+ self.cross_attn_dict = {}
82
+ self.self_attn_dict = {}
83
+ self.duplication_dict = {}
84
+ self.query_index = {}
85
+ self.output = {}
86
+ self.flags = {}
87
+ self.spatial_memory = {}
88
+
89
+ # initialize duplication
90
+ for key, values in self.p_duplication.items():
91
+ for name in values:
92
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
93
+
94
+ # initialize flag
95
+ self.flags = {"object": True}
96
+ self.flags.update(flags)
97
+
98
+ # initialize task
99
+ self.task = task
100
+
101
+ # initialize output
102
+ if self.task_switch['mask']:
103
+ self.output['predictions_class'] = []
104
+ self.output['predictions_mask'] = []
105
+ self.output['predictions_maskemb'] = []
106
+
107
+ if self.task_switch['bbox']:
108
+ self.output['predictions_bbox'] = []
109
+
110
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
111
+ self.output['predictions_pos_spatial'] = []
112
+ self.output['predictions_neg_spatial'] = []
113
+
114
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
115
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
116
+
117
+ if (self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True)) \
118
+ or (self.task_switch['audio'] and ('audio' in self.flags and self.flags['audio']==True)):
119
+ self.output['predictions_caption'] = []
120
+
121
+ if self.task_switch['visual']:
122
+ self.output['predictions_pos_visual'] = []
123
+ self.output['predictions_neg_visual'] = []
124
+
125
+ # initialize cross_attn, whether the variable is used in cross attention
126
+ for key, values in self.p_cross_attn.items():
127
+ for name in values:
128
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
129
+
130
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
131
+ for key, values in self.p_self_attn.items():
132
+ for name in values:
133
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
134
+
135
+ # initialize masking
136
+ self.masking = self.p_masking
137
+
138
+ # initialize query_index
139
+ self.query_index = {"all":[0, None]}
140
+
141
+
142
+ def set(self, name, _type, output=None, pos=None, var=None):
143
+ if var is not None:
144
+ self.attn_variables[name] = var
145
+ elif name in self.duplication_dict:
146
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
147
+ self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
148
+ else:
149
+ var = Variable(output, name, _type, pos)
150
+ self.attn_variables[name] = var
151
+
152
+ def set_results(self, results):
153
+ for name in self.cross_attn_name:
154
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
155
+ for key in self.output:
156
+ self.output[key].append(results[key])
157
+
158
+ def set_maskings(self, name, masking):
159
+ self.attn_variables[name].masking = masking
160
+
161
+ def cross_attn_variables(self, ):
162
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
163
+ if (value==True) and (key in self.attn_variables)
164
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
165
+ self.cross_attn_name = cross_attn_name
166
+
167
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
168
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
169
+
170
+ index = 0
171
+ for name in cross_attn_name:
172
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
173
+ index += self.attn_variables[name].output.shape[0]
174
+ return output, pos_emb
175
+
176
+ def cross_attn_mask(self, size, num_heads):
177
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
178
+
179
+ # hard code memories_spatial to previous selected mask
180
+ if 'memories_spatial' in self.cross_attn_name:
181
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
182
+ bs,c,_,_ = memory_attn_mask.shape
183
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
184
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
185
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
186
+
187
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
188
+ return attn_mask
189
+
190
+ def self_attn(self, bs, num_heads):
191
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
192
+ if len(value)>0 and key in self.attn_variables
193
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
194
+ self.self_attn_name = self_attn_name
195
+
196
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
197
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
198
+
199
+ index = 0
200
+ for name in self_attn_name:
201
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
202
+ index += self.attn_variables[name].output.shape[0]
203
+
204
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
205
+ self_attn_pair = []
206
+ # build self_attention mask by query interaction
207
+ for key1, value in self.self_attn_dict.items():
208
+ for key2 in value:
209
+ if key1 not in self_attn_name or key2 not in self_attn_name:
210
+ # exclude the variables that are not used in the current layer
211
+ continue
212
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
213
+ self_attn_pair += [[key1, key2]]
214
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
215
+
216
+ # build self_attention mask by masking, for bidirectional
217
+ for key in self.masking:
218
+ if key in self_attn_name:
219
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
220
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
221
+
222
+ # build self_attention mask by masking, for uni-directional
223
+ for key1, key2 in self_attn_pair:
224
+ if key1 not in self_attn_name or key2 not in self_attn_name:
225
+ # exclude the variables that are not used in the current layer
226
+ continue
227
+ if key1 in self.masking:
228
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
229
+ if key2 in self.masking:
230
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
231
+
232
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
233
+ return output, pos_emb, self_attn_mask
234
+
235
+ def update_variables(self, output, mode):
236
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
237
+ for key in name_set:
238
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
239
+
240
+ def update_spatial_results(self, results):
241
+ v_emb = results['pred_smaskembs']
242
+ pred_smasks = results['pred_smasks']
243
+
244
+ s_emb = results['pred_pspatials']
245
+ pred_logits = v_emb @ s_emb.transpose(1,2)
246
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
247
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
248
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
249
+ pred_masks_pos = pred_smasks[logits_idx][:,None,]
250
+
251
+ extra = {"prev_mask": pred_masks_pos}
252
+ return extra
253
+
254
+ def organize_output(self, ):
255
+ outputs = {}
256
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
257
+
258
+ for key, values in self.output.items():
259
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
260
+ if idx_name not in self.query_index:
261
+ continue
262
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
263
+ for idx, aux_values in enumerate(self.output[key][:-1]):
264
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
265
+ return outputs
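
Note: `AttentionDataStruct` is bookkeeping for the decoder loop: `reset()` declares which outputs a task produces, `set()` registers query/token variables, `cross_attn_variables()`/`self_attn()` concatenate them along the query dimension and record each variable's slice in `query_index`, `update_variables()`/`set_results()` write layer outputs back, and `organize_output()` splits the stacked predictions per query group. A toy sketch of the central slicing idea (variable names and sizes are illustrative):

```python
# Toy sketch: variables are concatenated for attention, and query_index records
# each variable's slice so the concatenated output can be split back afterwards.
import torch

variables = {
    'queries_object': torch.randn(101, 2, 256),
    'tokens_grounding': torch.randn(12, 2, 256),
}
query_index, start = {}, 0
for name, v in variables.items():
    query_index[name] = [start, start + v.shape[0]]
    start += v.shape[0]

output = torch.cat(list(variables.values()))   # (113, 2, 256), fed to the attention layers
# ... attention layers update `output`, then it is split back per variable ...
per_var = {n: output[s:e] for n, (s, e) in query_index.items()}
print({n: t.shape for n, t in per_var.items()})
```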
modeling/interface/prototype/attention_data_struct_seemv0.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ predict_name_matcher = {"predictions_class": ["pred_logits"],
6
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
7
+ "predictions_caption":["pred_captions", "pred_gtexts"],
8
+ "predictions_maskemb":["pred_smaskembs"],
9
+ "predictions_pos_spatial":["pred_pspatials"],
10
+ "predictions_neg_spatial":["pred_nspatials"],}
11
+
12
+ predict_index_matcher = {"predictions_class": ["queries_object"],
13
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
14
+ "predictions_caption": ["queries_object", "queries_grounding"],
15
+ "predictions_maskemb":["queries_spatial"],
16
+ "predictions_pos_spatial":["all"],
17
+ "predictions_neg_spatial":["all"],}
18
+
19
+ class Variable(object):
20
+ '''
21
+ Store dataset variable for attention
22
+ output: embedding that accumulates during cross/self attention
23
+ pos: positional embedding that is fixed during cross/self attention
24
+ name: name of the variable
25
+ type: type of the variable, e.g. queries, tokens
26
+ attn_mask: attention mask for cross attention
27
+ masking: masking for padding
28
+ '''
29
+ def __init__(self, output, name, _type, pos=None):
30
+ self.output = output
31
+ self.pos = pos
32
+ self.name = name
33
+ self.type = _type
34
+ self.attn_mask = None
35
+ self.masking = None
36
+
37
+ def copy(self,):
38
+ output = self.output.clone() if self.output is not None else None
39
+ pos = self.pos.clone() if self.pos is not None else None
40
+ return Variable(output, self.name, self.type, pos)
41
+
42
+ class AttentionDataStruct(nn.Module):
43
+ '''
44
+ Store dataset structure for cross/self attention
45
+ task_switch: switch for different tasks
46
+
47
+ p_attn_variables: prototype of variables that is used in cross/self attention
48
+ p_self_attn: prototype of variables that is used in self attention
49
+ p_cross_attn: prototype of variables that is used in cross attention
50
+ p_iter: prototype of iteration for different queries
51
+ p_masking: prototype of masking for different tokens
52
+ p_duplication: prototype of duplication for different queries
53
+ '''
54
+ def __init__(self, attn_arch, task_switch):
55
+ super(AttentionDataStruct, self).__init__()
56
+ self.task_switch = task_switch
57
+
58
+ # p stands for prototype
59
+ self.p_attn_variables = attn_arch['VARIABLE']
60
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
61
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
62
+ self.p_masking = attn_arch['MASKING']
63
+ self.p_duplication = attn_arch['DUPLICATION']
64
+
65
+ self.num_layers = attn_arch['NUM_LAYERS']
66
+
67
+ def reset(self, flags, task, extra):
68
+ # reset variables
69
+ self.attn_variables = {}
70
+ self.cross_attn_dict = {}
71
+ self.self_attn_dict = {}
72
+ self.duplication_dict = {}
73
+ self.query_index = {}
74
+ self.output = {}
75
+ self.flags = {}
76
+ self.spatial_memory = {}
77
+
78
+ # initialize duplication
79
+ for key, values in self.p_duplication.items():
80
+ for name in values:
81
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
82
+
83
+ # initialize flag
84
+ self.flags = {"object": True}
85
+ self.flags.update(flags)
86
+
87
+ # initialize task
88
+ self.task = task
89
+
90
+ # initialize output
91
+ if self.task_switch['mask']:
92
+ self.output['predictions_class'] = []
93
+ self.output['predictions_mask'] = []
94
+
95
+ if self.task_switch['bbox']:
96
+ self.output['predictions_bbox'] = []
97
+
98
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
99
+ self.output['predictions_maskemb'] = []
100
+ self.output['predictions_pos_spatial'] = []
101
+ self.output['predictions_neg_spatial'] = []
102
+ # self.spatial_memory['spatial_query_mode'] = extra['spatial_query_mode']
103
+
104
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
105
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
106
+
107
+ if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
108
+ self.output['predictions_caption'] = []
109
+
110
+ # initialize cross_attn, whether the variable is used in cross attention
111
+ for key, values in self.p_cross_attn.items():
112
+ for name in values:
113
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
114
+
115
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
116
+ for key, values in self.p_self_attn.items():
117
+ for name in values:
118
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
119
+
120
+ # initialize masking
121
+ self.masking = self.p_masking
122
+
123
+ # initialize query_index
124
+ self.query_index = {"all":[0, None]}
125
+
126
+
127
+ def set(self, name, _type, output=None, pos=None, var=None):
128
+ if var is not None:
129
+ self.attn_variables[name] = var
130
+ elif name in self.duplication_dict:
131
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
132
+ self.attn_variables[name] = self.attn_variables[self.duplication_dict[name]].copy()
133
+ else:
134
+ var = Variable(output, name, _type, pos)
135
+ self.attn_variables[name] = var
136
+
137
+ def set_results(self, results):
138
+ for name in self.cross_attn_name:
139
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
140
+ for key in self.output:
141
+ self.output[key].append(results[key])
142
+
143
+ def set_maskings(self, name, masking):
144
+ self.attn_variables[name].masking = masking
145
+
146
+ def cross_attn_variables(self, ):
147
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
148
+ if (value==True) and (key in self.attn_variables)
149
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
150
+ self.cross_attn_name = cross_attn_name
151
+
152
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
153
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
154
+
155
+ index = 0
156
+ for name in cross_attn_name:
157
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
158
+ index += self.attn_variables[name].output.shape[0]
159
+ return output, pos_emb
160
+
161
+ def cross_attn_mask(self, size, num_heads):
162
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
163
+
164
+ # hard code memories_spatial to previous selected mask
165
+ if 'memories_spatial' in self.cross_attn_name:
166
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
167
+ bs,c,_,_ = memory_attn_mask.shape
168
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
169
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
170
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask
171
+
172
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
173
+ return attn_mask
174
+
175
+ def self_attn(self, bs, num_heads):
176
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
177
+ if len(value)>0 and key in self.attn_variables
178
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
179
+ self.self_attn_name = self_attn_name
180
+
181
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
182
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
183
+
184
+ index = 0
185
+ for name in self_attn_name:
186
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
187
+ index += self.attn_variables[name].output.shape[0]
188
+
189
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
190
+ self_attn_pair = []
191
+ # build self_attention mask by query interaction
192
+ for key1, value in self.self_attn_dict.items():
193
+ for key2 in value:
194
+ if key1 not in self_attn_name or key2 not in self_attn_name:
195
+ # exclude the variables that are not used in the current layer
196
+ continue
197
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
198
+ self_attn_pair += [[key1, key2]]
199
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
200
+
201
+ # build self_attention mask by masking, for bidirectional
202
+ for key in self.masking:
203
+ if key in self_attn_name:
204
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
205
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
206
+
207
+ # build self_attention mask by masking, for uni-directional
208
+ for key1, key2 in self_attn_pair:
209
+ if key1 not in self_attn_name or key2 not in self_attn_name:
210
+ # exclude the variables that are not used in the current layer
211
+ continue
212
+ if key1 in self.masking:
213
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
214
+ if key2 in self.masking:
215
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
216
+
217
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
218
+ return output, pos_emb, self_attn_mask
219
+
220
+ def update_variables(self, output, mode):
221
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
222
+ for key in name_set:
223
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
224
+
225
+ def update_spatial_results(self, results):
226
+ v_emb = results['pred_smaskembs']
227
+ pred_smasks = results['pred_smasks']
228
+
229
+ s_emb = results['pred_pspatials']
230
+ pred_logits = v_emb @ s_emb.transpose(1,2)
231
+ logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
232
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
233
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
234
+ pred_masks_pos = pred_smasks[logits_idx][:,None,]
235
+
236
+ # s_emb = results['pred_nspatials']
237
+ # pred_logits = v_emb @ s_emb.transpose(1,2)
238
+ # logits_idx_y = pred_logits[:,:,0].max(dim=1)[1]
239
+ # logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)
240
+ # logits_idx = torch.stack([logits_idx_x, logits_idx_y]).tolist()
241
+ # pred_masks_neg = pred_smasks[logits_idx][:,None,]
242
+ # # clip the negative mask to 0, and then multiply by -1
243
+ # pred_masks_neg = (pred_masks_neg.clip(0) * -1)
244
+ # keep_neg = (s_emb.sum(dim=list(range(1, s_emb.dim()))) != 0).float()
245
+ # pred_masks_neg = pred_masks_neg * keep_neg[:,None,None,None]
246
+ # extra = {"prev_mask": pred_masks_pos + pred_masks_neg}
247
+
248
+ extra = {"prev_mask": pred_masks_pos}
249
+ return extra
250
+
251
+ def organize_output(self, ):
252
+ outputs = {}
253
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
254
+ for key, values in self.output.items():
255
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
256
+ if idx_name not in self.query_index:
257
+ continue
258
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
259
+ for idx, aux_values in enumerate(self.output[key][:-1]):
260
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
261
+ if self.task == 'spatial' or self.task == 'refimg':
262
+ outputs = self.update_spatial_results(outputs)
263
+ # outputs = self.update_spatial_results(outputs)
264
+ return outputs
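
Note: in this variant `update_spatial_results` selects, per batch element, the spatial mask whose embedding is most similar to the positive click embedding (a dot-product argmax over queries) and feeds it back as `extra['prev_mask']` for the spatial-memory path. A numeric sketch of that selection with toy shapes:

```python
# Numeric sketch of the argmax selection in update_spatial_results (toy shapes).
import torch

bs, nq, d, h, w = 2, 5, 8, 16, 16
v_emb = torch.randn(bs, nq, d)        # pred_smaskembs: one embedding per spatial query
s_emb = torch.randn(bs, 1, d)         # pred_pspatials: positive click embedding
pred_smasks = torch.randn(bs, nq, h, w)

pred_logits = v_emb @ s_emb.transpose(1, 2)    # (bs, nq, 1) similarity scores
best = pred_logits[:, :, 0].max(dim=1)[1]      # best query index per batch item
batch = torch.arange(bs)
prev_mask = pred_smasks[batch, best][:, None]  # (bs, 1, h, w), passed on as extra['prev_mask']
print(prev_mask.shape)
```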
modeling/interface/prototype/attention_data_struct_seemv1.py ADDED
@@ -0,0 +1,302 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ predict_name_matcher = {"predictions_class": ["pred_logits"],
6
+ "predictions_mask":["pred_masks", "pred_gmasks", "pred_smasks"],
7
+ "predictions_caption":["pred_captions", "pred_gtexts", "pred_stexts"],
8
+ "predictions_maskemb":["pred_smaskembs"],
9
+ "predictions_pos_spatial":["pred_pspatials"],
10
+ "predictions_neg_spatial":["pred_nspatials"],}
11
+
12
+ predict_index_matcher = {"predictions_class": ["queries_object"],
13
+ "predictions_mask":["queries_object", "queries_grounding", "queries_spatial"],
14
+ "predictions_caption": ["queries_object", "queries_grounding", "queries_spatial"],
15
+ "predictions_maskemb":["queries_spatial"],
16
+ "predictions_pos_spatial":["all"],
17
+ "predictions_neg_spatial":["all"],}
18
+
19
+ class Variable(object):
20
+ '''
21
+ Store dataset variable for attention
22
+ output: embedding that accumulates during cross/self attention
23
+ pos: positional embedding that is fixed during cross/self attention
24
+ name: name of the variable
25
+ type: type of the variable, e.g. queries, tokens
26
+ attn_mask: attention mask for cross attention
27
+ masking: masking for padding
28
+ '''
29
+ def __init__(self, output, name, _type, pos=None):
30
+ self.output = output
31
+ self.pos = pos
32
+ self.name = name
33
+ self.type = _type
34
+ self.attn_mask = None
35
+ self.masking = None
36
+
37
+ def copy(self,):
38
+ output = self.output.clone() if self.output is not None else None
39
+ pos = self.pos.clone() if self.pos is not None else None
40
+ return Variable(output, self.name, self.type, pos)
41
+
42
+ def rand_sample(self, max_len):
43
+ rand_idx = torch.randint(0, len(self.pos), (max_len,))
44
+ self.output = self.output[rand_idx]
45
+ self.pos = self.pos[rand_idx]
46
+ return self
47
+
48
+ class AttentionDataStruct(nn.Module):
49
+ '''
50
+ Store dataset structure for cross/self attention
51
+ task_switch: switch for different tasks
52
+
53
+ p_attn_variables: prototype of variables that is used in cross/self attention
54
+ p_self_attn: prototype of variables that is used in self attention
55
+ p_cross_attn: prototype of variables that is used in cross attention
56
+ p_iter: prototype of iteration for different queries
57
+ p_masking: prototype of masking for different tokens
58
+ p_duplication: prototype of duplication for different queries
59
+ '''
60
+ def __init__(self, attn_arch, task_switch):
61
+ super(AttentionDataStruct, self).__init__()
62
+ self.task_switch = task_switch
63
+
64
+ # p stands for prototype
65
+ self.p_attn_variables = attn_arch['VARIABLE']
66
+ self.p_self_attn = attn_arch['SELF_ATTENTION']
67
+ self.p_cross_attn = attn_arch['CROSS_ATTENTION']
68
+ self.p_masking = attn_arch['MASKING']
69
+ self.p_duplication = attn_arch['DUPLICATION']
70
+
71
+ self.num_layers = attn_arch['NUM_LAYERS']
72
+
73
+ def reset(self, flags, task, extra):
74
+ # reset variables
75
+ self.attn_variables = {}
76
+ self.cross_attn_dict = {}
77
+ self.self_attn_dict = {}
78
+ self.duplication_dict = {}
79
+ self.query_index = {}
80
+ self.output = {}
81
+ self.flags = {}
82
+ self.spatial_memory = {}
83
+ self.extra = {}
84
+
85
+ # initialize duplication
86
+ for key, values in self.p_duplication.items():
87
+ for name in values:
88
+ self.duplication_dict["{}_{}".format(key, name)] = self.p_duplication[key][name]
89
+
90
+ # initialize flag
91
+ self.flags = {"object": True}
92
+ self.flags.update(flags)
93
+
94
+ # initialize task
95
+ self.task = task
96
+
97
+ # initialize output
98
+ if self.task_switch['mask']:
99
+ self.output['predictions_class'] = []
100
+ self.output['predictions_mask'] = []
101
+
102
+ if self.task_switch['bbox']:
103
+ self.output['predictions_bbox'] = []
104
+
105
+ if self.task_switch['spatial'] and ('memories_spatial' in self.flags and self.flags['memories_spatial']==True):
106
+ self.spatial_memory['prev_batch_mask'] = extra['prev_mask']
107
+
108
+ if self.task_switch['grounding'] and ('grounding' in self.flags and self.flags['grounding']==True):
109
+ self.output['predictions_caption'] = []
110
+
111
+ if self.task_switch['spatial'] and ('spatial' in self.flags and self.flags['spatial']==True):
112
+ self.output['predictions_maskemb'] = []
113
+ self.output['predictions_pos_spatial'] = []
114
+ self.output['predictions_neg_spatial'] = []
115
+ self.output['predictions_mask'] = [] if 'predictions_mask' not in self.output else self.output['predictions_mask']
116
+ self.output['predictions_class'] = [] if 'predictions_class' not in self.output else self.output['predictions_class']
117
+ self.output['predictions_caption'] = [] if 'predictions_caption' not in self.output else self.output['predictions_caption']
118
+
119
+ # initialize cross_attn, whether the variable is used in cross attention
120
+ for key, values in self.p_cross_attn.items():
121
+ for name in values:
122
+ self.cross_attn_dict["{}_{}".format(key, name)] = self.p_cross_attn[key][name]
123
+
124
+ # initialize self_attn, whether the variable is used in self attention, and the interactions between queries
125
+ for key, values in self.p_self_attn.items():
126
+ for name in values:
127
+ self.self_attn_dict["{}_{}".format(key, name)] = self.p_self_attn[key][name]
128
+
129
+ # initialize masking
130
+ self.masking = self.p_masking
131
+
132
+ # initialize query_index
133
+ self.query_index = {"all":[0, None]}
134
+
135
+
136
+ def set(self, name, _type, output=None, pos=None, var=None, sample_size=None):
137
+ if var is not None:
138
+ self.attn_variables[name] = var
139
+ elif name in self.duplication_dict:
140
+ assert self.duplication_dict[name] in self.attn_variables, "Duplication variable {} is not initialized yet.".format(name)
141
+ var = self.attn_variables[self.duplication_dict[name]].copy()
142
+ if sample_size is not None:
143
+ var = var.rand_sample(sample_size)
144
+ self.attn_variables[name] = var
145
+ else:
146
+ var = Variable(output, name, _type, pos)
147
+ self.attn_variables[name] = var
148
+
149
+ def set_results(self, results):
150
+ for name in self.cross_attn_name:
151
+ self.attn_variables[name].attn_mask = results['attn_mask'][:,self.query_index[name][0]:self.query_index[name][1]]
152
+ for key in self.output:
153
+ self.output[key].append(results[key])
154
+
155
+ def set_maskings(self, name, masking):
156
+ self.attn_variables[name].masking = masking
157
+
158
+ def set_extra(self, extra):
159
+ self.extra.update(extra)
160
+
161
+ def cross_attn_variables(self, ):
162
+ cross_attn_name = [key for key, value in self.cross_attn_dict.items()
163
+ if (value==True) and (key in self.attn_variables)
164
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
165
+ self.cross_attn_name = cross_attn_name
166
+
167
+ output = torch.cat([self.attn_variables[name].output for name in cross_attn_name])
168
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in cross_attn_name])
169
+
170
+ index = 0
171
+ for name in cross_attn_name:
172
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
173
+ index += self.attn_variables[name].output.shape[0]
174
+ return output, pos_emb
175
+
176
+ def cross_attn_mask(self, size, num_heads):
177
+ attn_mask = torch.cat([self.attn_variables[name].attn_mask for name in self.cross_attn_name], dim=1)
178
+
179
+ # hard code memories_spatial to previous selected mask
180
+ if 'memories_spatial' in self.cross_attn_name:
181
+ memory_attn_mask = self.spatial_memory['prev_batch_mask']
182
+ bs,c,_,_ = memory_attn_mask.shape
183
+ memory_attn_mask = F.interpolate(memory_attn_mask, size, mode='bilinear', align_corners=False)
184
+ memory_attn_mask = (memory_attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
185
+ repeat = (self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]) // c
186
+ mem_len = self.query_index['memories_spatial'][1] - self.query_index['memories_spatial'][0]
187
+ probs = torch.tensor([1./repeat for i in range(c)])
188
+ indices = torch.multinomial(probs, num_samples=mem_len, replacement=True).sort()[0]
189
+ attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = memory_attn_mask[:,indices]
190
+ self.extra['memory_indices'] = indices
191
+
192
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
193
+ return attn_mask
194
+
195
+ def self_attn(self, bs, num_heads):
196
+ self_attn_name = [key for key, value in self.self_attn_dict.items()
197
+ if len(value)>0 and key in self.attn_variables
198
+ and ((key not in self.flags) or (key in self.flags and self.flags[key]==True))]
199
+ self.self_attn_name = self_attn_name
200
+
201
+ output = torch.cat([self.attn_variables[name].output for name in self_attn_name])
202
+ pos_emb = torch.cat([self.attn_variables[name].pos for name in self_attn_name])
203
+
204
+ index = 0
205
+ for name in self_attn_name:
206
+ self.query_index[name] = [index, index + self.attn_variables[name].output.shape[0]]
207
+ index += self.attn_variables[name].output.shape[0]
208
+
209
+ self_attn_mask = torch.ones((bs, output.shape[0], output.shape[0]), dtype=torch.bool, device=output.device)
210
+ self_attn_pair = []
211
+ # build self_attention mask by query interaction
212
+ for key1, value in self.self_attn_dict.items():
213
+ for key2 in value:
214
+ if key1 not in self_attn_name or key2 not in self_attn_name:
215
+ # exclude the variables that are not used in the current layer
216
+ continue
217
+ if (key1 in self.masking or key2 in self.masking) and (key1 != key2):
218
+ self_attn_pair += [[key1, key2]]
219
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1], self.query_index[key2][0]:self.query_index[key2][1]] = False
220
+
221
+ # build self_attention mask by masking, for bidirectional
222
+ for key in self.masking:
223
+ if key in self_attn_name:
224
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]][self.attn_variables[key].masking] = True
225
+ self_attn_mask[:,self.query_index[key][0]:self.query_index[key][1],self.query_index[key][0]:self.query_index[key][1]].transpose(1,2)[self.attn_variables[key].masking] = True
226
+
227
+ # build self_attention mask by masking, for uni-directional
228
+ for key1, key2 in self_attn_pair:
229
+ if key1 not in self_attn_name or key2 not in self_attn_name:
230
+ # exclude the variables that are not used in the current layer
231
+ continue
232
+ if key1 in self.masking:
233
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]][self.attn_variables[key1].masking] = True # HACK, not verified
234
+ if key2 in self.masking:
235
+ self_attn_mask[:,self.query_index[key1][0]:self.query_index[key1][1],self.query_index[key2][0]:self.query_index[key2][1]].transpose(1,2)[self.attn_variables[key2].masking] = True
236
+
237
+ # build self_attention masking for spatial queries
238
+ # spatial query attend with itself
239
+ if 'queries_spatial' in self_attn_name and 'tokens_spatial' in self_attn_name:
240
+ diag_mask = ~(torch.eye(self.extra['spatial_query_number']).repeat_interleave(self.extra['sample_size'],dim=0).repeat_interleave(self.extra['sample_size'],dim=1)).bool()
241
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1]] = diag_mask[None,]
242
+ # spatial query attend with spatial token
243
+ indices = self.extra['spatial_indices'].permute(0,2,1)
244
+ diag_index = torch.arange(self.extra['spatial_query_number'], device=indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
245
+ diag_mask = ~(indices == diag_index)
246
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
247
+ # spatial token attend with itself
248
+ diag_mask = ~(indices == indices.transpose(1,2))
249
+ self_attn_mask[:,self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1],self.query_index['tokens_spatial'][0]:self.query_index['tokens_spatial'][1]] = diag_mask
250
+
251
+ if 'memory_indices' in self.extra:
252
+ # spatial query attend with memory
253
+ memory_indices = self.extra['memory_indices'][None,None,:]
254
+ diag_index = torch.arange(self.extra['spatial_query_number'], device=memory_indices.device).repeat_interleave(self.extra['sample_size'],dim=0)[None,:,None]
255
+ diag_mask = ~(diag_index == memory_indices)
256
+ self_attn_mask[:,self.query_index['queries_spatial'][0]:self.query_index['queries_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
257
+ # memory attend with itself
258
+ diag_mask = ~(memory_indices == memory_indices.transpose(1,2))
259
+ self_attn_mask[:,self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1],self.query_index['memories_spatial'][0]:self.query_index['memories_spatial'][1]] = diag_mask
260
+
261
+ self_attn_mask = self_attn_mask.repeat_interleave(num_heads, dim=0)
262
+ return output, pos_emb, self_attn_mask
263
+
264
+ def update_variables(self, output, mode):
265
+ name_set = self.self_attn_name if mode=='self_attn' else self.cross_attn_name
266
+ for key in name_set:
267
+ self.attn_variables[key].output = output[self.query_index[key][0]:self.query_index[key][1]]
268
+
269
+ def update_spatial_results(self, results):
270
+ v_emb = results['pred_smaskembs']
271
+ pred_smasks = results['pred_smasks']
272
+
273
+ s_emb = results['pred_pspatials']
274
+ diag_mask = ~(torch.eye(self.extra['spatial_query_number'], device=s_emb.device).repeat_interleave(self.extra['sample_size'],dim=0)).bool()
275
+ offset = torch.zeros_like(diag_mask, device=s_emb.device).float()
276
+ offset.masked_fill_(diag_mask, float("-inf"))
277
+
278
+ pred_logits = v_emb @ s_emb.transpose(1,2) + offset[None,]
279
+ bs,_,ns=pred_logits.shape
280
+ _,_,h,w=pred_smasks.shape
281
+
282
+ logits_idx_y = pred_logits.max(dim=1)[1]
283
+ logits_idx_x = torch.arange(len(logits_idx_y), device=logits_idx_y.device)[:,None].repeat(1, logits_idx_y.shape[1])
284
+ logits_idx = torch.stack([logits_idx_x, logits_idx_y]).view(2,-1).tolist()
285
+ pred_masks_pos = pred_smasks[logits_idx].reshape(bs,ns,h,w)
286
+ extra = {"prev_mask": pred_masks_pos}
287
+ return extra
288
+
289
+ def organize_output(self, ):
290
+ outputs = {}
291
+ outputs['aux_outputs'] = [{} for i in range(self.num_layers)]
292
+ for key, values in self.output.items():
293
+ for _key, idx_name in zip(predict_name_matcher[key], predict_index_matcher[key]):
294
+ if idx_name not in self.query_index:
295
+ continue
296
+ outputs[_key] = self.output[key][-1][:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
297
+ for idx, aux_values in enumerate(self.output[key][:-1]):
298
+ outputs['aux_outputs'][idx][_key] = aux_values[:,self.query_index[idx_name][0]:self.query_index[idx_name][1]]
299
+ if self.task == 'spatial' or self.task == 'refimg':
300
+ outputs = self.update_spatial_results(outputs)
301
+ # outputs = self.update_spatial_results(outputs)
302
+ return outputs
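
Note: the v1 struct extends the spatial path to multiple simultaneous clicks: spatial queries are replicated per click, and the boolean self-attention mask is made block-diagonal so each click's queries and point tokens only attend within their own group (True means "masked out"). A small sketch of how that block-diagonal mask is built:

```python
# Sketch of the block-diagonal self-attention mask for multi-click spatial queries.
import torch

spatial_query_number, sample_size = 3, 4   # e.g. 3 clicks, 4 replicated queries per click
eye = torch.eye(spatial_query_number)
block = eye.repeat_interleave(sample_size, dim=0).repeat_interleave(sample_size, dim=1)
diag_mask = ~block.bool()                  # (12, 12): True (blocked) outside each click's block
print(diag_mask.int())
```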
modeling/interface/seem_demo.py ADDED
@@ -0,0 +1,397 @@
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All At Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import logging
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from torch import nn, Tensor
12
+ from torch.nn import functional as F
13
+
14
+ from timm.models.layers import trunc_normal_
15
+ from detectron2.layers import Conv2d
16
+ import fvcore.nn.weight_init as weight_init
17
+
18
+ from .build import register_decoder
19
+ from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
20
+ from .prototype.attention_data_struct_seemdemo import AttentionDataStruct
21
+ from ..utils import rand_sample_plain as rand_sample
22
+ from ..utils import prepare_features, configurable
23
+ from ..modules import PositionEmbeddingSine
24
+ from ..modules.point_features import point_sample
25
+
26
+
27
+ class SEEMDecoder(nn.Module):
28
+
29
+ @configurable
30
+ def __init__(
31
+ self,
32
+ lang_encoder: nn.Module,
33
+ in_channels,
34
+ mask_classification=True,
35
+ *,
36
+ hidden_dim: int,
37
+ dim_proj: int,
38
+ num_queries: int,
39
+ contxt_len: int,
40
+ nheads: int,
41
+ dim_feedforward: int,
42
+ dec_layers: int,
43
+ pre_norm: bool,
44
+ mask_dim: int,
45
+ task_switch: dict,
46
+ enforce_input_project: bool,
47
+ max_spatial_len: int,
48
+ attn_arch: dict,
49
+ ):
50
+ """
51
+ NOTE: this interface is experimental.
52
+ Args:
53
+ in_channels: channels of the input features
54
+ mask_classification: whether to add mask classifier or not
55
+ num_classes: number of classes
56
+ hidden_dim: Transformer feature dimension
57
+ num_queries: number of queries
58
+ nheads: number of heads
59
+ dim_feedforward: feature dimension in feedforward network
60
+ enc_layers: number of Transformer encoder layers
61
+ dec_layers: number of Transformer decoder layers
62
+ pre_norm: whether to use pre-LayerNorm or not
63
+ mask_dim: mask feature dimension
64
+ enforce_input_project: add input project 1x1 conv even if input
65
+ channels and hidden dim is identical
66
+ """
67
+ super().__init__()
68
+ assert mask_classification, "Only support mask classification model"
69
+ self.mask_classification = mask_classification
70
+
71
+ # positional encoding
72
+ N_steps = hidden_dim // 2
73
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
74
+
75
+ # define Transformer decoder here
76
+ self.num_heads = nheads
77
+ self.num_layers = dec_layers
78
+ self.contxt_len = contxt_len
79
+ self.transformer_self_attention_layers = nn.ModuleList()
80
+ self.transformer_cross_attention_layers = nn.ModuleList()
81
+ self.transformer_ffn_layers = nn.ModuleList()
82
+
83
+ for _ in range(self.num_layers):
84
+ self.transformer_self_attention_layers.append(
85
+ SelfAttentionLayer(
86
+ d_model=hidden_dim,
87
+ nhead=nheads,
88
+ dropout=0.0,
89
+ normalize_before=pre_norm,
90
+ )
91
+ )
92
+
93
+ self.transformer_cross_attention_layers.append(
94
+ CrossAttentionLayer(
95
+ d_model=hidden_dim,
96
+ nhead=nheads,
97
+ dropout=0.0,
98
+ normalize_before=pre_norm,
99
+ )
100
+ )
101
+
102
+ self.transformer_ffn_layers.append(
103
+ FFNLayer(
104
+ d_model=hidden_dim,
105
+ dim_feedforward=dim_feedforward,
106
+ dropout=0.0,
107
+ normalize_before=pre_norm,
108
+ )
109
+ )
110
+
111
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
112
+
113
+ self.num_queries = num_queries
114
+ # learnable query features
115
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
116
+ # learnable query p.e.
117
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
118
+ # learnable positive negative indicator
119
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
120
+
121
+ # level embedding (we always use 3 scales)
122
+ self.num_feature_levels = 3
123
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
124
+ self.input_proj = nn.ModuleList()
125
+
126
+ for _ in range(self.num_feature_levels):
127
+ if in_channels != hidden_dim or enforce_input_project:
128
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
129
+ weight_init.c2_xavier_fill(self.input_proj[-1])
130
+ else:
131
+ self.input_proj.append(nn.Sequential())
132
+
133
+ self.task_switch = task_switch
134
+ self.query_index = {}
135
+
136
+ # output FFNs
137
+ self.lang_encoder = lang_encoder
138
+ if self.task_switch['mask']:
139
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
140
+
141
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
142
+ trunc_normal_(self.class_embed, std=.02)
143
+
144
+ if task_switch['bbox']:
145
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
146
+
147
+ if task_switch['spatial']:
148
+ # spatial query
149
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
150
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
151
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
152
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
153
+
154
+ self.max_spatial_len = max_spatial_len
155
+ # spatial memory
156
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
157
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
158
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
159
+
160
+ # build AttentionDataStruct
161
+ attn_arch['NUM_LAYERS'] = self.num_layers
162
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
163
+
164
+ @classmethod
165
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
166
+ ret = {}
167
+
168
+ ret["lang_encoder"] = lang_encoder
169
+ ret["in_channels"] = in_channels
170
+ ret["mask_classification"] = mask_classification
171
+
172
+ enc_cfg = cfg['MODEL']['ENCODER']
173
+ dec_cfg = cfg['MODEL']['DECODER']
174
+
175
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
176
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
177
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
178
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
179
+
180
+ # Transformer parameters:
181
+ ret["nheads"] = dec_cfg['NHEADS']
182
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
183
+
184
+ # NOTE: because we add learnable query features which requires supervision,
185
+ # we add minus 1 to decoder layers to be consistent with our loss
186
+ # implementation: that is, number of auxiliary losses is always
187
+ # equal to number of decoder layers. With learnable query features, the number of
188
+ # auxiliary losses equals number of decoders plus 1.
189
+ assert dec_cfg['DEC_LAYERS'] >= 1
190
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
191
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
192
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
193
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
194
+ ret["task_switch"] = extra['task_switch']
195
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
196
+
197
+ # attn data struct
198
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
199
+
200
+ return ret
201
+
202
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
203
+ # x is a list of multi-scale feature
204
+ assert len(x) == self.num_feature_levels; del mask
205
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg'
206
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
207
+ visual_extra_flag = 'visual_query_pos' in extra.keys()
208
+ audio_extra_flag = 'audio_tokens' in extra.keys()
209
+ spatial_memory_flag = 'prev_mask' in extra.keys()
210
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag, "visual": visual_extra_flag, "audio": audio_extra_flag}
211
+ self.attention_data.reset(flags, task, extra)
212
+
213
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
214
+ _, bs, _ = src[0].shape
215
+
216
+ # QxNxC
217
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
218
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
219
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
220
+
221
+ if self.task_switch['spatial'] and spatial_extra_flag:
222
+ # get divisor
223
+ _,h,w = extra['spatial_query_pos_mask'][0].shape
224
+ divisor = torch.tensor([h,w], device=output.device)[None,]
225
+
226
+ # Get mean pos spatial query
227
+ non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
228
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
229
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
230
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
231
+ spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
232
+
233
+ # Get mean neg spatial query
234
+ non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
235
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
236
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
237
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
238
+ spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
239
+
240
+ # merge positive and negative sample points for self attention
241
+
242
+ # Get layerwise spatial query
243
+ src_spatial_queries = []
244
+ src_spatial_maskings = []
245
+ for i in range(len(src)):
246
+ hw,_,dc = src[i].shape
247
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
248
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
249
+
250
+ non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
251
+ non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
252
+ non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
253
+
254
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
255
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
256
+
257
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
258
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
259
+ non_zero_query_point[non_zero_query_mask] = 0
260
+
261
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
262
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
263
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
264
+
265
+ src_spatial_queries += [spatial_tokens]
266
+ src_spatial_maskings += [non_zero_query_mask]
267
+
268
+ if 'refimg' in task:
269
+ output_refimg = {}
270
+ output_refimg['visual_query_pos'] = spatial_query_pos
271
+ output_refimg['visual_query_neg'] = spatial_query_neg
272
+ output_refimg['src_visual_queries'] = src_spatial_queries
273
+ output_refimg['src_visual_maskings'] = src_spatial_maskings
274
+ return output_refimg
275
+
276
+ if task != 'demo':
277
+ # Get object query for spatial index
278
+ self.attention_data.set('queries_spatial', 'queries')
279
+
280
+ if self.task_switch['visual'] and visual_extra_flag:
281
+ visual_query_pos = extra['visual_query_pos']
282
+ visual_query_neg = extra['visual_query_neg']
283
+ src_visual_queries = extra['src_visual_queries']
284
+ src_visual_maskings = extra['src_visual_maskings']
285
+
286
+ if self.task_switch['grounding'] and grounding_extra_flag:
287
+ # Get grounding tokens
288
+ grounding_tokens = extra['grounding_tokens']
289
+ _grounding_tokens = grounding_tokens.detach().clone()
290
+
291
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
292
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
293
+
294
+ if self.task_switch['audio'] and audio_extra_flag:
295
+ # Get audio tokens
296
+ grounding_tokens = extra['audio_tokens']
297
+ _grounding_tokens = grounding_tokens.detach().clone()
298
+
299
+ self.attention_data.set('tokens_audio', 'tokens', grounding_tokens, _grounding_tokens)
300
+ self.attention_data.set_maskings('tokens_audio', extra['audio_nonzero_mask'])
301
+
302
+ output, query_embed = self.attention_data.cross_attn_variables()
303
+ # prediction heads on learnable query features
304
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
305
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
306
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
307
+ results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
308
+ results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
309
+ self.attention_data.set_results(results)
310
+
311
+ for i in range(self.num_layers):
312
+ level_index = i % self.num_feature_levels
313
+ # CROSS ATTENTION
314
+ output, avg_attn = self.transformer_cross_attention_layers[i](
315
+ output, src[level_index],
316
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
317
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
318
+ pos=pos[level_index], query_pos=query_embed
319
+ )
320
+ self.attention_data.update_variables(output, 'cross_attn')
321
+
322
+ # SELF ATTENTION
323
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
324
+ if self.task_switch['spatial'] and spatial_extra_flag:
325
+ # get spatial tokens
326
+ spatial_tokens = src_spatial_queries[level_index]
327
+ _spatial_tokens = spatial_tokens.detach().clone()
328
+
329
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
330
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
331
+
332
+ if self.task_switch['visual'] and visual_extra_flag:
333
+ # get visual tokens
334
+ visual_tokens = src_visual_queries[level_index]
335
+ _visual_tokens = visual_tokens.detach().clone()
336
+
337
+ self.attention_data.set('tokens_visual', 'tokens', visual_tokens, _visual_tokens)
338
+ self.attention_data.set_maskings('tokens_visual', src_visual_maskings[level_index])
339
+
340
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
341
+ output = self.transformer_self_attention_layers[i](
342
+ output, tgt_mask=self_attn_mask,
343
+ tgt_key_padding_mask=None,
344
+ query_pos=query_embed)
345
+
346
+ # FFN
347
+ output = self.transformer_ffn_layers[i](
348
+ output
349
+ )
350
+
351
+ self.attention_data.update_variables(output, 'self_attn')
352
+ output, query_embed = self.attention_data.cross_attn_variables()
353
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
354
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
355
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
356
+ results["predictions_pos_visual"] = visual_query_pos.transpose(0,1) if visual_extra_flag else None
357
+ results["predictions_neg_visual"] = visual_query_neg.transpose(0,1) if visual_extra_flag else None
358
+ self.attention_data.set_results(results)
359
+
360
+ return self.attention_data.organize_output()
361
+
362
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
363
+ decoder_output = self.decoder_norm(output)
364
+ decoder_output = decoder_output.transpose(0, 1)
365
+ class_embed = decoder_output @ self.class_embed
366
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
367
+ mask_embed = self.mask_embed(decoder_output)
368
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
369
+
370
+ outputs_bbox = [None for i in range(len(outputs_mask))]
371
+ if self.task_switch['bbox']:
372
+ outputs_bbox = self.bbox_embed(decoder_output)
373
+
374
+ # NOTE: prediction is of higher-resolution
375
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
376
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
377
+
378
+ # must use bool type
379
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
380
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
381
+ attn_mask = attn_mask.detach()
382
+
383
+ outputs_caption = class_embed
384
+
385
+ results = {
386
+ "attn_mask": attn_mask,
387
+ "predictions_class": outputs_class,
388
+ "predictions_mask": outputs_mask,
389
+ "predictions_bbox": outputs_bbox,
390
+ "predictions_caption": outputs_caption,
391
+ "predictions_maskemb": mask_embed,
392
+ }
393
+ return results
394
+
395
+ @register_decoder
396
+ def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
397
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
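
A minimal sketch of how this decoder gets configured, assuming only the config keys read by from_config above; every value is a hypothetical placeholder, and the real ATTENTION_ARCH section needs more fields in practice than shown here.

# Hypothetical values for illustration only -- not the shipped configuration.
toy_cfg = {
    'MODEL': {
        'DIM_PROJ': 512,
        'TEXT': {'CONTEXT_LENGTH': 77},
        'ENCODER': {'MASK_DIM': 512},
        'DECODER': {
            'HIDDEN_DIM': 512,
            'NUM_OBJECT_QUERIES': 101,
            'NHEADS': 8,
            'DIM_FEEDFORWARD': 2048,
            'DEC_LAYERS': 10,            # from_config passes DEC_LAYERS - 1 to the decoder
            'PRE_NORM': False,
            'ENFORCE_INPUT_PROJ': False,
            'MAX_SPATIAL_LEN': [512, 512, 512],
        },
    },
    'ATTENTION_ARCH': {'SPATIAL_MEMORIES': 32},   # incomplete; the AttentionDataStruct reads further keys
}
# get_seem_interface(toy_cfg, in_channels, lang_encoder, mask_classification,
#                    extra={'task_switch': ...}) would then assemble the decoder.
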
modeling/interface/seem_v0.py ADDED
@@ -0,0 +1,392 @@
 
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All at Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import logging
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from torch import nn, Tensor
12
+ from torch.nn import functional as F
13
+
14
+ from timm.models.layers import trunc_normal_
15
+ from detectron2.layers import Conv2d
16
+ import fvcore.nn.weight_init as weight_init
17
+
18
+ from .build import register_decoder
19
+ from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
20
+ from .prototype.attention_data_struct_seemv0 import AttentionDataStruct
21
+ from ..utils import rand_sample_plain as rand_sample
22
+ from ..utils import prepare_features, configurable
23
+ from ..modules import PositionEmbeddingSine
24
+ from ..modules.point_features import point_sample
25
+
26
+
27
+ class SEEMDecoder(nn.Module):
28
+
29
+ @configurable
30
+ def __init__(
31
+ self,
32
+ lang_encoder: nn.Module,
33
+ in_channels,
34
+ mask_classification=True,
35
+ *,
36
+ hidden_dim: int,
37
+ dim_proj: int,
38
+ num_queries: int,
39
+ contxt_len: int,
40
+ nheads: int,
41
+ dim_feedforward: int,
42
+ dec_layers: int,
43
+ pre_norm: bool,
44
+ mask_dim: int,
45
+ task_switch: dict,
46
+ enforce_input_project: bool,
47
+ max_spatial_len: int,
48
+ attn_arch: dict,
49
+ ):
50
+ """
51
+ NOTE: this interface is experimental.
52
+ Args:
53
+ in_channels: channels of the input features
54
+ mask_classification: whether to add mask classifier or not
55
+ num_classes: number of classes
56
+ hidden_dim: Transformer feature dimension
57
+ num_queries: number of queries
58
+ nheads: number of heads
59
+ dim_feedforward: feature dimension in feedforward network
60
+ enc_layers: number of Transformer encoder layers
61
+ dec_layers: number of Transformer decoder layers
62
+ pre_norm: whether to use pre-LayerNorm or not
63
+ mask_dim: mask feature dimension
64
+ enforce_input_project: add an input projection 1x1 conv even if input
65
+ channels and hidden dim are identical
66
+ """
67
+ super().__init__()
68
+ assert mask_classification, "Only support mask classification model"
69
+ self.mask_classification = mask_classification
70
+
71
+ # positional encoding
72
+ N_steps = hidden_dim // 2
73
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
74
+
75
+ # define Transformer decoder here
76
+ self.num_heads = nheads
77
+ self.num_layers = dec_layers
78
+ self.contxt_len = contxt_len
79
+ self.transformer_self_attention_layers = nn.ModuleList()
80
+ self.transformer_cross_attention_layers = nn.ModuleList()
81
+ self.transformer_ffn_layers = nn.ModuleList()
82
+
83
+ for _ in range(self.num_layers):
84
+ self.transformer_self_attention_layers.append(
85
+ SelfAttentionLayer(
86
+ d_model=hidden_dim,
87
+ nhead=nheads,
88
+ dropout=0.0,
89
+ normalize_before=pre_norm,
90
+ )
91
+ )
92
+
93
+ self.transformer_cross_attention_layers.append(
94
+ CrossAttentionLayer(
95
+ d_model=hidden_dim,
96
+ nhead=nheads,
97
+ dropout=0.0,
98
+ normalize_before=pre_norm,
99
+ )
100
+ )
101
+
102
+ self.transformer_ffn_layers.append(
103
+ FFNLayer(
104
+ d_model=hidden_dim,
105
+ dim_feedforward=dim_feedforward,
106
+ dropout=0.0,
107
+ normalize_before=pre_norm,
108
+ )
109
+ )
110
+
111
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
112
+
113
+ self.num_queries = num_queries
114
+ # learnable query features
115
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
116
+ # learnable query p.e.
117
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
118
+
119
+ # level embedding (we always use 3 scales)
120
+ self.num_feature_levels = 3
121
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
122
+ self.input_proj = nn.ModuleList()
123
+
124
+ for _ in range(self.num_feature_levels):
125
+ if in_channels != hidden_dim or enforce_input_project:
126
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
127
+ weight_init.c2_xavier_fill(self.input_proj[-1])
128
+ else:
129
+ self.input_proj.append(nn.Sequential())
130
+
131
+ self.task_switch = task_switch
132
+ self.query_index = {}
133
+
134
+ # output FFNs
135
+ self.lang_encoder = lang_encoder
136
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
137
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
138
+ trunc_normal_(self.class_embed, std=.02)
139
+
140
+ if task_switch['bbox']:
141
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
142
+
143
+ if task_switch['spatial']:
144
+ # spatial query
145
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
146
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
147
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
148
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
149
+
150
+ self.max_spatial_len = max_spatial_len
151
+ # spatial memory
152
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
153
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
154
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
155
+
156
+ # learnable positive negative indicator
157
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
158
+
159
+ # build AttentionDataStruct
160
+ attn_arch['NUM_LAYERS'] = self.num_layers
161
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
162
+
163
+ @classmethod
164
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
165
+ ret = {}
166
+
167
+ ret["lang_encoder"] = lang_encoder
168
+ ret["in_channels"] = in_channels
169
+ ret["mask_classification"] = mask_classification
170
+
171
+ enc_cfg = cfg['MODEL']['ENCODER']
172
+ dec_cfg = cfg['MODEL']['DECODER']
173
+
174
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
175
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
176
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
177
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
178
+
179
+ # Transformer parameters:
180
+ ret["nheads"] = dec_cfg['NHEADS']
181
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
182
+
183
+ # NOTE: because the learnable query features require supervision,
184
+ # we subtract 1 from the configured decoder layers to stay consistent with our loss
185
+ # implementation: the number of auxiliary losses always equals the number of
186
+ # decoder layers. With learnable query features, the number of auxiliary
187
+ # losses equals the number of decoder layers plus 1.
188
+ assert dec_cfg['DEC_LAYERS'] >= 1
189
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
190
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
191
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
192
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
193
+ ret["task_switch"] = extra['task_switch']
194
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
195
+
196
+ # attn data struct
197
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
198
+
199
+ return ret
200
+
201
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
202
+ # x is a list of multi-scale feature
203
+ assert len(x) == self.num_feature_levels; del mask
204
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
205
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
206
+ spatial_memory_flag = 'prev_mask' in extra.keys()
207
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
208
+ self.attention_data.reset(flags, task, extra)
209
+
210
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
211
+ _, bs, _ = src[0].shape
212
+
213
+ # QxNxC
214
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
215
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
216
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
217
+
218
+ if self.task_switch['spatial'] and spatial_extra_flag:
219
+ if 'refimg_tokens' not in extra:
220
+ # get divisor
221
+ _,h,w = extra['spatial_query_pos_mask'][0].shape
222
+ divisor = torch.tensor([h,w], device=output.device)[None,]
223
+
224
+ # Get mean pos spatial query
225
+ non_zero_pos_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
226
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
227
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
228
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
229
+ spatial_query_pos = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask)]).transpose(0,1).nan_to_num()
230
+
231
+ # Get mean neg spatial query
232
+ non_zero_neg_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
233
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
234
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
235
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
236
+ spatial_query_neg = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask)]).transpose(0,1).nan_to_num()
237
+
238
+ # merge positive and negative sample points for self attention
239
+ # pos_neg_points = [x|y for x,y in zip(extra['spatial_query_pos_mask'], extra['spatial_query_neg_mask'])]
240
+
241
+ # Get layerwise spatial query
242
+ src_spatial_queries = []
243
+ src_spatial_maskings = []
244
+ for i in range(len(src)):
245
+ hw,_,dc = src[i].shape
246
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
247
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
248
+
249
+ non_zero_query_point_pos = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
250
+ non_zero_query_point_neg = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
251
+ non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
252
+
253
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
254
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
255
+
256
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
257
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
258
+ non_zero_query_point[non_zero_query_mask] = 0
259
+
260
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
261
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
262
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
263
+
264
+ src_spatial_queries += [spatial_tokens]
265
+ src_spatial_maskings += [non_zero_query_mask]
266
+
267
+ if 'refimg' in task:
268
+ output_refimg = {}
269
+ output_refimg['spatial_query_pos'] = spatial_query_pos
270
+ output_refimg['spatial_query_neg'] = spatial_query_neg
271
+ output_refimg['src_spatial_queries'] = src_spatial_queries
272
+ output_refimg['src_spatial_maskings'] = src_spatial_maskings
273
+ return output_refimg
274
+ else:
275
+ spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
276
+ spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
277
+ src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
278
+ src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']
279
+
280
+ # Get object query for spatial index
281
+ self.attention_data.set('queries_spatial', 'queries')
282
+
283
+ # set spatial memory
284
+ spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
285
+ spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
286
+ self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)
287
+
288
+ # if 'queries_spatial' in extra:
289
+ # self.attention_data.set('queries_spatial', 'queries', var=extra['queries_spatial'])
290
+
291
+ # if spatial_memory_flag:
292
+ # prev_mask = (extra['prev_mask'].sigmoid() > 0.5).detach()
293
+ # non_zero_query_point = [rand_sample((m.nonzero()[:,1:]/divisor).t(), self.max_spatial_len[-1]).t() for m in prev_mask]
294
+ # non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
295
+ # non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
296
+ # spatial_memory = point_sample(mask_features, non_zero_query_point.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
297
+ # spatial_memory = torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(spatial_memory.transpose(1,2), ~non_zero_query_mask)]).transpose(0,1).nan_to_num()
298
+
299
+ if self.task_switch['grounding'] and grounding_extra_flag:
300
+ # Get grounding tokens
301
+ grounding_tokens = extra['grounding_tokens']
302
+ _grounding_tokens = grounding_tokens.detach().clone()
303
+
304
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
305
+ self.attention_data.set('queries_grounding', 'queries')
306
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
307
+
308
+ output, query_embed = self.attention_data.cross_attn_variables()
309
+ # prediction heads on learnable query features
310
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
311
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
312
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
313
+ self.attention_data.set_results(results)
314
+
315
+ for i in range(self.num_layers):
316
+ level_index = i % self.num_feature_levels
317
+ # CROSS ATTENTION
318
+ output, avg_attn = self.transformer_cross_attention_layers[i](
319
+ output, src[level_index],
320
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
321
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
322
+ pos=pos[level_index], query_pos=query_embed
323
+ )
324
+ self.attention_data.update_variables(output, 'cross_attn')
325
+
326
+ # SELF ATTENTION
327
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
328
+ if self.task_switch['spatial'] and spatial_extra_flag:
329
+ # get spatial tokens
330
+ spatial_tokens = src_spatial_queries[level_index]
331
+ _spatial_tokens = spatial_tokens.detach().clone()
332
+
333
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
334
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
335
+
336
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
337
+
338
+ output = self.transformer_self_attention_layers[i](
339
+ output, tgt_mask=self_attn_mask,
340
+ tgt_key_padding_mask=None,
341
+ query_pos=query_embed)
342
+
343
+ # FFN
344
+ output = self.transformer_ffn_layers[i](
345
+ output
346
+ )
347
+
348
+ self.attention_data.update_variables(output, 'self_attn')
349
+ output, query_embed = self.attention_data.cross_attn_variables()
350
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
351
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
352
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
353
+ self.attention_data.set_results(results)
354
+
355
+ return self.attention_data.organize_output()
356
+
357
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
358
+ decoder_output = self.decoder_norm(output)
359
+ decoder_output = decoder_output.transpose(0, 1)
360
+ class_embed = decoder_output @ self.class_embed
361
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
362
+ mask_embed = self.mask_embed(decoder_output)
363
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
364
+
365
+ outputs_bbox = [None for i in range(len(outputs_mask))]
366
+ if self.task_switch['bbox']:
367
+ outputs_bbox = self.bbox_embed(decoder_output)
368
+
369
+ # NOTE: prediction is of higher-resolution
370
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
371
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
372
+
373
+ # must use bool type
374
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
375
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
376
+ attn_mask = attn_mask.detach()
377
+
378
+ outputs_caption = class_embed
379
+
380
+ results = {
381
+ "attn_mask": attn_mask,
382
+ "predictions_class": outputs_class,
383
+ "predictions_mask": outputs_mask,
384
+ "predictions_bbox": outputs_bbox,
385
+ "predictions_caption": outputs_caption,
386
+ "predictions_maskemb": mask_embed,
387
+ }
388
+ return results
389
+
390
+ @register_decoder
391
+ def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
392
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
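
A standalone sketch, with made-up shapes rather than repository code, of the attn_mask bookkeeping in forward_prediction_heads above: the predicted masks are resized to the next feature level, passed through a sigmoid, thresholded at 0.5, and repeated across heads so that True marks positions a query may not attend to.

import torch
import torch.nn.functional as F

# Toy shapes only; the point is the tensor bookkeeping, not the values.
B, Q, H, W, num_heads = 2, 5, 32, 32, 8
outputs_mask = torch.randn(B, Q, H, W)       # per-query mask logits
attn_mask_target_size = (16, 16)             # resolution of the next feature level

attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size,
                          mode="bilinear", align_corners=False)
# [B, Q, h, w] -> [B, Q, h*w] -> [B, heads, Q, h*w] -> [B*heads, Q, h*w];
# True marks positions a query is NOT allowed to attend to in cross-attention.
attn_mask = (attn_mask.sigmoid().flatten(2)
             .unsqueeze(1).repeat(1, num_heads, 1, 1)
             .flatten(0, 1) < 0.5).detach()
print(attn_mask.shape)  # torch.Size([16, 5, 256])
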
modeling/interface/seem_v1.py ADDED
@@ -0,0 +1,389 @@
 
1
+ # --------------------------------------------------------
2
+ # SEEM -- Segment Everything Everywhere All at Once
3
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
4
+ # Written by Xueyan Zou ([email protected])
5
+ # --------------------------------------------------------
6
+
7
+ import logging
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from torch import nn, Tensor
12
+ from torch.nn import functional as F
13
+
14
+ from timm.models.layers import trunc_normal_
15
+ from detectron2.layers import Conv2d
16
+ import fvcore.nn.weight_init as weight_init
17
+
18
+ from .build import register_decoder
19
+ from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
20
+ from .prototype.attention_data_struct_seemv1 import AttentionDataStruct
21
+ from ..utils import rand_sample, prepare_features, configurable
22
+ from ..modules import PositionEmbeddingSine
23
+ from ..modules.point_features import point_sample
24
+
25
+
26
+ class SEEMDecoder(nn.Module):
27
+
28
+ @configurable
29
+ def __init__(
30
+ self,
31
+ lang_encoder: nn.Module,
32
+ in_channels,
33
+ mask_classification=True,
34
+ *,
35
+ hidden_dim: int,
36
+ dim_proj: int,
37
+ num_queries: int,
38
+ contxt_len: int,
39
+ nheads: int,
40
+ dim_feedforward: int,
41
+ dec_layers: int,
42
+ pre_norm: bool,
43
+ mask_dim: int,
44
+ task_switch: dict,
45
+ enforce_input_project: bool,
46
+ max_spatial_len: int,
47
+ attn_arch: dict,
48
+ ):
49
+ """
50
+ NOTE: this interface is experimental.
51
+ Args:
52
+ in_channels: channels of the input features
53
+ mask_classification: whether to add mask classifier or not
54
+ num_classes: number of classes
55
+ hidden_dim: Transformer feature dimension
56
+ num_queries: number of queries
57
+ nheads: number of heads
58
+ dim_feedforward: feature dimension in feedforward network
59
+ enc_layers: number of Transformer encoder layers
60
+ dec_layers: number of Transformer decoder layers
61
+ pre_norm: whether to use pre-LayerNorm or not
62
+ mask_dim: mask feature dimension
63
+ enforce_input_project: add an input projection 1x1 conv even if input
64
+ channels and hidden dim are identical
65
+ """
66
+ super().__init__()
67
+ assert mask_classification, "Only support mask classification model"
68
+ self.mask_classification = mask_classification
69
+
70
+ # positional encoding
71
+ N_steps = hidden_dim // 2
72
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
73
+
74
+ # define Transformer decoder here
75
+ self.num_heads = nheads
76
+ self.num_layers = dec_layers
77
+ self.contxt_len = contxt_len
78
+ self.transformer_self_attention_layers = nn.ModuleList()
79
+ self.transformer_cross_attention_layers = nn.ModuleList()
80
+ self.transformer_ffn_layers = nn.ModuleList()
81
+
82
+ for _ in range(self.num_layers):
83
+ self.transformer_self_attention_layers.append(
84
+ SelfAttentionLayer(
85
+ d_model=hidden_dim,
86
+ nhead=nheads,
87
+ dropout=0.0,
88
+ normalize_before=pre_norm,
89
+ )
90
+ )
91
+
92
+ self.transformer_cross_attention_layers.append(
93
+ CrossAttentionLayer(
94
+ d_model=hidden_dim,
95
+ nhead=nheads,
96
+ dropout=0.0,
97
+ normalize_before=pre_norm,
98
+ )
99
+ )
100
+
101
+ self.transformer_ffn_layers.append(
102
+ FFNLayer(
103
+ d_model=hidden_dim,
104
+ dim_feedforward=dim_feedforward,
105
+ dropout=0.0,
106
+ normalize_before=pre_norm,
107
+ )
108
+ )
109
+
110
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
111
+
112
+ self.num_queries = num_queries
113
+ # learnable query features
114
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
115
+ # learnable query p.e.
116
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
117
+
118
+ # level embedding (we always use 3 scales)
119
+ self.num_feature_levels = 3
120
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
121
+ self.input_proj = nn.ModuleList()
122
+
123
+ for _ in range(self.num_feature_levels):
124
+ if in_channels != hidden_dim or enforce_input_project:
125
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
126
+ weight_init.c2_xavier_fill(self.input_proj[-1])
127
+ else:
128
+ self.input_proj.append(nn.Sequential())
129
+
130
+ self.task_switch = task_switch
131
+ self.query_index = {}
132
+
133
+ # output FFNs
134
+ self.lang_encoder = lang_encoder
135
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
136
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
137
+ trunc_normal_(self.class_embed, std=.02)
138
+
139
+ if task_switch['bbox']:
140
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
141
+
142
+ if task_switch['spatial']:
143
+ # spatial query
144
+ self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(3)])
145
+ trunc_normal_(self.mask_sptial_embed[0], std=.02)
146
+ trunc_normal_(self.mask_sptial_embed[1], std=.02)
147
+ trunc_normal_(self.mask_sptial_embed[2], std=.02)
148
+
149
+ self.max_spatial_len = max_spatial_len
150
+ # spatial memory
151
+ num_spatial_memories = attn_arch['SPATIAL_MEMORIES']
152
+ self.spatial_embed = nn.Embedding(num_spatial_memories, hidden_dim)
153
+ self.spatial_featured = nn.Embedding(num_spatial_memories, hidden_dim)
154
+
155
+ # learnable positive negative indicator
156
+ self.pn_indicator = nn.Embedding(2, hidden_dim)
157
+
158
+ # build AttentionDataStruct
159
+ attn_arch['NUM_LAYERS'] = self.num_layers
160
+ self.attention_data = AttentionDataStruct(attn_arch, task_switch)
161
+ self.sample_size = attn_arch['QUERY_NUMBER']
162
+
163
+ @classmethod
164
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
165
+ ret = {}
166
+
167
+ ret["lang_encoder"] = lang_encoder
168
+ ret["in_channels"] = in_channels
169
+ ret["mask_classification"] = mask_classification
170
+
171
+ enc_cfg = cfg['MODEL']['ENCODER']
172
+ dec_cfg = cfg['MODEL']['DECODER']
173
+
174
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
175
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
176
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
177
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
178
+
179
+ # Transformer parameters:
180
+ ret["nheads"] = dec_cfg['NHEADS']
181
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
182
+
183
+ # NOTE: because the learnable query features require supervision,
184
+ # we subtract 1 from the configured decoder layers to stay consistent with our loss
185
+ # implementation: the number of auxiliary losses always equals the number of
186
+ # decoder layers. With learnable query features, the number of auxiliary
187
+ # losses equals the number of decoder layers plus 1.
188
+ assert dec_cfg['DEC_LAYERS'] >= 1
189
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
190
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
191
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
192
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
193
+ ret["task_switch"] = extra['task_switch']
194
+ ret["max_spatial_len"] = dec_cfg['MAX_SPATIAL_LEN']
195
+
196
+ # attn data struct
197
+ ret["attn_arch"] = cfg['ATTENTION_ARCH']
198
+
199
+ return ret
200
+
201
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
202
+ # x is a list of multi-scale feature
203
+ assert len(x) == self.num_feature_levels; del mask
204
+ spatial_extra_flag = 'spatial_query_pos_mask' in extra.keys() or task == 'refimg' or 'refimg_tokens' in extra
205
+ grounding_extra_flag = 'grounding_tokens' in extra.keys()
206
+ spatial_memory_flag = 'prev_mask' in extra.keys()
207
+ flags = {"spatial": spatial_extra_flag, "grounding": grounding_extra_flag, "memories_spatial": spatial_memory_flag}
208
+ self.attention_data.reset(flags, task, extra)
209
+
210
+ src, pos, size_list = prepare_features(x, self.num_feature_levels, self.pe_layer, self.input_proj, self.level_embed)
211
+ _,bs,_ = src[0].shape
212
+
213
+ # QxNxC
214
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
215
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
216
+ self.attention_data.set('queries_object', 'queries', output, query_embed)
217
+
218
+ if self.task_switch['spatial'] and spatial_extra_flag:
219
+ if 'refimg_tokens' not in extra:
220
+ # get divisor
221
+ c,h,w = extra['spatial_query_pos_mask'][0].shape
222
+ divisor = torch.tensor([1,h,w], device=output.device)[None,]
223
+
224
+ # Get mean pos spatial query
225
+ non_zero_pos_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_pos_mask']]
226
+ non_zero_pos_index = [m[:,0:1].long() for m in non_zero_pos_point]
227
+ non_zero_pos_point = nn.utils.rnn.pad_sequence(non_zero_pos_point, padding_value=-1).permute(1,0,2)
228
+ non_zero_pos_index = nn.utils.rnn.pad_sequence(non_zero_pos_index, padding_value=-1).permute(1,0,2)[:,:,0]
229
+ non_zero_pos_mask = (non_zero_pos_point.sum(dim=-1) < 0)
230
+ spatial_query_pos = point_sample(mask_features, non_zero_pos_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
231
+ num_mask_per_batch = [len(m) for m in extra['spatial_query_pos_mask']]
232
+ spatial_query_pos = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_pos.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_pos.transpose(1,2), ~non_zero_pos_mask, non_zero_pos_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
233
+
234
+ # Get mean neg spatial query
235
+ non_zero_neg_point = [rand_sample(m, divisor, self.max_spatial_len[-1]).t() for m in extra['spatial_query_neg_mask']]
236
+ non_zero_neg_index = [m[:,0:1].long() for m in non_zero_neg_point]
237
+ non_zero_neg_point = nn.utils.rnn.pad_sequence(non_zero_neg_point, padding_value=-1).permute(1,0,2)
238
+ non_zero_neg_index = nn.utils.rnn.pad_sequence(non_zero_neg_index, padding_value=-1).permute(1,0,2)[:,:,0]
239
+ non_zero_neg_mask = (non_zero_neg_point.sum(dim=-1) < 0)
240
+ spatial_query_neg = point_sample(mask_features, non_zero_neg_point[:,:,1:].flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
241
+ num_mask_per_batch = [len(m) for m in extra['spatial_query_neg_mask']]
242
+ spatial_query_neg = nn.utils.rnn.pad_sequence([torch.stack([x[ns==n].mean(dim=0, keepdim=False) if (ns==n).sum() > 0 else -torch.ones((x.shape[1]), device=spatial_query_neg.device) for n in range(mb)]) for x, m, ns, mb in zip(spatial_query_neg.transpose(1,2), ~non_zero_neg_mask, non_zero_neg_index, num_mask_per_batch)], padding_value=-1).nan_to_num()
243
+ # Get layerwise spatial query
244
+ src_spatial_queries = []
245
+ src_spatial_maskings = []
246
+ src_spatial_indices = []
247
+ for i in range(len(src)):
248
+ hw,_,dc = src[i].shape
249
+ src_mask_features = src[i].view(size_list[i][0],size_list[i][1],bs,dc)
250
+ src_mask_features = src_mask_features @ self.mask_sptial_embed[i]
251
+
252
+ non_zero_query_point_pos = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_pos_mask']]
253
+ non_zero_query_point_neg = [rand_sample(m, divisor, self.max_spatial_len[i]).t() for m in extra['spatial_query_neg_mask']]
254
+ non_zero_query_point = [torch.cat([x[:,1:],y[:,1:]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
255
+ non_zero_query_index = [torch.cat([x[:,0:1],y[:,0:1]], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
256
+
257
+ pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
258
+ pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
259
+
260
+ non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
261
+ non_zero_query_index = nn.utils.rnn.pad_sequence(non_zero_query_index, padding_value=-1).permute(1,0,2)
262
+ non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
263
+ non_zero_query_point[non_zero_query_mask] = 0
264
+
265
+ spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
266
+ spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
267
+ spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]
268
+
269
+ src_spatial_queries += [spatial_tokens]
270
+ src_spatial_maskings += [non_zero_query_mask]
271
+ src_spatial_indices += [non_zero_query_index]
272
+
273
+ if 'refimg' in task:
274
+ output_refimg = {}
275
+ output_refimg['spatial_query_pos'] = spatial_query_pos
276
+ output_refimg['spatial_query_neg'] = spatial_query_neg
277
+ output_refimg['src_spatial_queries'] = src_spatial_queries
278
+ output_refimg['src_spatial_maskings'] = src_spatial_maskings
279
+ return output_refimg
280
+ else:
281
+ spatial_query_pos = extra['refimg_tokens']['spatial_query_pos']
282
+ spatial_query_neg = extra['refimg_tokens']['spatial_query_neg']
283
+ src_spatial_queries = extra['refimg_tokens']['src_spatial_queries']
284
+ src_spatial_maskings = extra['refimg_tokens']['src_spatial_maskings']
285
+
286
+ # Get object query for spatial index
287
+ self.attention_data.set_extra({"spatial_query_number": len(spatial_query_pos), "sample_size": self.sample_size})
288
+ self.attention_data.set('queries_spatial', 'queries', sample_size=self.sample_size*len(spatial_query_pos))
289
+
290
+ # set spatial memory
291
+ spatial_output = self.spatial_featured.weight.unsqueeze(1).repeat(1, bs, 1)
292
+ spatial_embed = self.spatial_embed.weight.unsqueeze(1).repeat(1, bs, 1)
293
+ self.attention_data.set('memories_spatial', 'memories', spatial_output, spatial_embed)
294
+
295
+ if self.task_switch['grounding'] and grounding_extra_flag:
296
+ # Get grounding tokens
297
+ grounding_tokens = extra['grounding_tokens']
298
+ _grounding_tokens = grounding_tokens.detach().clone()
299
+
300
+ self.attention_data.set('tokens_grounding', 'tokens', grounding_tokens, _grounding_tokens)
301
+ self.attention_data.set('queries_grounding', 'queries')
302
+ self.attention_data.set_maskings('tokens_grounding', extra['grounding_nonzero_mask'])
303
+
304
+ output, query_embed = self.attention_data.cross_attn_variables()
305
+ # prediction heads on learnable query features
306
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
307
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
308
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
309
+ self.attention_data.set_results(results)
310
+
311
+ for i in range(self.num_layers):
312
+ level_index = i % self.num_feature_levels
313
+ # CROSS ATTENTION
314
+ output, avg_attn = self.transformer_cross_attention_layers[i](
315
+ output, src[level_index],
316
+ memory_mask=self.attention_data.cross_attn_mask(size_list[level_index], self.num_heads),
317
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
318
+ pos=pos[level_index], query_pos=query_embed
319
+ )
320
+ self.attention_data.update_variables(output, 'cross_attn')
321
+
322
+ # SELF ATTENTION
323
+ self_attn_mask = torch.zeros((bs, self.num_queries, self.num_queries), device=query_embed.device).bool() # Default False (attend oq)
324
+ if self.task_switch['spatial'] and spatial_extra_flag:
325
+ # get spatial tokens
326
+ spatial_tokens = src_spatial_queries[level_index]
327
+ _spatial_tokens = spatial_tokens.detach().clone()
328
+
329
+ self.attention_data.set('tokens_spatial', 'tokens', spatial_tokens, _spatial_tokens)
330
+ self.attention_data.set_maskings('tokens_spatial', src_spatial_maskings[level_index])
331
+ self.attention_data.set_extra({"spatial_indices": src_spatial_indices[level_index]})
332
+
333
+ output, query_embed, self_attn_mask = self.attention_data.self_attn(bs, self.num_heads)
334
+
335
+ output = self.transformer_self_attention_layers[i](
336
+ output, tgt_mask=self_attn_mask,
337
+ tgt_key_padding_mask=None,
338
+ query_pos=query_embed)
339
+
340
+ # FFN
341
+ output = self.transformer_ffn_layers[i](
342
+ output
343
+ )
344
+
345
+ self.attention_data.update_variables(output, 'self_attn')
346
+ output, query_embed = self.attention_data.cross_attn_variables()
347
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i)
348
+ results["predictions_pos_spatial"] = spatial_query_pos.transpose(0,1) if spatial_extra_flag else None
349
+ results["predictions_neg_spatial"] = spatial_query_neg.transpose(0,1) if spatial_extra_flag else None
350
+ self.attention_data.set_results(results)
351
+
352
+ return self.attention_data.organize_output()
353
+
354
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1):
355
+ decoder_output = self.decoder_norm(output)
356
+ decoder_output = decoder_output.transpose(0, 1)
357
+ class_embed = decoder_output @ self.class_embed
358
+ outputs_class = self.lang_encoder.compute_similarity(class_embed)
359
+ mask_embed = self.mask_embed(decoder_output)
360
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
361
+
362
+ outputs_bbox = [None for i in range(len(outputs_mask))]
363
+ if self.task_switch['bbox']:
364
+ outputs_bbox = self.bbox_embed(decoder_output)
365
+
366
+ # NOTE: prediction is of higher-resolution
367
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
368
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
369
+
370
+ # must use bool type
371
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
372
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
373
+ attn_mask = attn_mask.detach()
374
+
375
+ outputs_caption = class_embed
376
+
377
+ results = {
378
+ "attn_mask": attn_mask,
379
+ "predictions_class": outputs_class,
380
+ "predictions_mask": outputs_mask,
381
+ "predictions_bbox": outputs_bbox,
382
+ "predictions_caption": outputs_caption,
383
+ "predictions_maskemb": mask_embed,
384
+ }
385
+ return results
386
+
387
+ @register_decoder
388
+ def get_seem_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
389
+ return SEEMDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
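
A small sketch, with illustrative sizes only, of the padding idiom the spatial-query branch above relies on: variable-length point sets are padded with -1 by pad_sequence, and the padding mask is recovered because genuine normalized coordinates never sum to a negative value.

import torch
from torch import nn

# Hypothetical per-image point sets with different lengths, coordinates in [0, 1).
points_per_image = [torch.rand(7, 2), torch.rand(3, 2)]
padded = nn.utils.rnn.pad_sequence(points_per_image, padding_value=-1).permute(1, 0, 2)
pad_mask = padded.sum(dim=-1) < 0      # True exactly on the padded rows
padded[pad_mask] = 0                   # zero out padding before sampling, as in the code above
print(padded.shape, pad_mask.shape)    # torch.Size([2, 7, 2]) torch.Size([2, 7])
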
modeling/interface/xdecoder.py ADDED
@@ -0,0 +1,497 @@
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from torch import nn, Tensor
13
+ from torch.nn import functional as F
14
+
15
+ from timm.models.layers import trunc_normal_
16
+ from detectron2.layers import Conv2d
17
+ import fvcore.nn.weight_init as weight_init
18
+
19
+ from .build import register_decoder
20
+ from .modules import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
21
+ from ..utils import configurable
22
+ from ..modules import PositionEmbeddingSine
23
+
24
+
25
+ class XDecoder(nn.Module):
26
+
27
+ @configurable
28
+ def __init__(
29
+ self,
30
+ lang_encoder: nn.Module,
31
+ in_channels,
32
+ mask_classification=True,
33
+ *,
34
+ hidden_dim: int,
35
+ dim_proj: int,
36
+ num_queries: int,
37
+ contxt_len: int,
38
+ nheads: int,
39
+ dim_feedforward: int,
40
+ dec_layers: int,
41
+ pre_norm: bool,
42
+ mask_dim: int,
43
+ task_switch: dict,
44
+ captioning_step: int,
45
+ enforce_input_project: bool,
46
+ ):
47
+ """
48
+ NOTE: this interface is experimental.
49
+ Args:
50
+ in_channels: channels of the input features
51
+ mask_classification: whether to add mask classifier or not
52
+ num_classes: number of classes
53
+ hidden_dim: Transformer feature dimension
54
+ num_queries: number of queries
55
+ nheads: number of heads
56
+ dim_feedforward: feature dimension in feedforward network
57
+ enc_layers: number of Transformer encoder layers
58
+ dec_layers: number of Transformer decoder layers
59
+ pre_norm: whether to use pre-LayerNorm or not
60
+ mask_dim: mask feature dimension
61
+ enforce_input_project: add an input projection 1x1 conv even if input
62
+ channels and hidden dim are identical
63
+ """
64
+ super().__init__()
65
+ assert mask_classification, "Only support mask classification model"
66
+ self.mask_classification = mask_classification
67
+
68
+ # positional encoding
69
+ N_steps = hidden_dim // 2
70
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
71
+
72
+ # define Transformer decoder here
73
+ self.num_heads = nheads
74
+ self.num_layers = dec_layers
75
+ self.contxt_len = contxt_len
76
+ self.transformer_self_attention_layers = nn.ModuleList()
77
+ self.transformer_cross_attention_layers = nn.ModuleList()
78
+ self.transformer_ffn_layers = nn.ModuleList()
79
+
80
+ for _ in range(self.num_layers):
81
+ self.transformer_self_attention_layers.append(
82
+ SelfAttentionLayer(
83
+ d_model=hidden_dim,
84
+ nhead=nheads,
85
+ dropout=0.0,
86
+ normalize_before=pre_norm,
87
+ )
88
+ )
89
+
90
+ self.transformer_cross_attention_layers.append(
91
+ CrossAttentionLayer(
92
+ d_model=hidden_dim,
93
+ nhead=nheads,
94
+ dropout=0.0,
95
+ normalize_before=pre_norm,
96
+ )
97
+ )
98
+
99
+ self.transformer_ffn_layers.append(
100
+ FFNLayer(
101
+ d_model=hidden_dim,
102
+ dim_feedforward=dim_feedforward,
103
+ dropout=0.0,
104
+ normalize_before=pre_norm,
105
+ )
106
+ )
107
+
108
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
109
+
110
+ self.num_queries = num_queries
111
+ # learnable query features
112
+ self.query_feat = nn.Embedding(num_queries, hidden_dim)
113
+ # learnable query p.e.
114
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
115
+
116
+ # level embedding (we always use 3 scales)
117
+ self.num_feature_levels = 3
118
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
119
+ self.input_proj = nn.ModuleList()
120
+
121
+ for _ in range(self.num_feature_levels):
122
+ if in_channels != hidden_dim or enforce_input_project:
123
+ self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
124
+ weight_init.c2_xavier_fill(self.input_proj[-1])
125
+ else:
126
+ self.input_proj.append(nn.Sequential())
127
+
128
+ self.task_switch = task_switch
129
+
130
+ # output FFNs
131
+ self.lang_encoder = lang_encoder
132
+ if self.task_switch['mask']:
133
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
134
+
135
+ self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
136
+ trunc_normal_(self.class_embed, std=.02)
137
+
138
+ if task_switch['bbox']:
139
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
140
+
141
+ # Caption Project and query
142
+ if task_switch['captioning']:
143
+ self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj))
144
+ trunc_normal_(self.caping_embed, std=.02)
145
+ self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim)
146
+ self.captioning_step = captioning_step
147
+
148
+ # register self_attn_mask to avoid information leakage; it covers the interaction between object queries, the class query, and the captioning queries
149
+ self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
150
+ self_attn_mask[:, :num_queries, num_queries:] = True # object+class queries do not attend to caption queries.
151
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption queries attend only to previous tokens.
152
+ self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object queries do not attend to the class query.
153
+ self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # the class query does not attend to object queries.
154
+ self.register_buffer("self_attn_mask", self_attn_mask)
155
+
156
+
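
A toy reconstruction of the mask registered above, with made-up sizes, so the blocked positions can be printed and inspected; it is only a sketch of the same four assignments, not repository code.

import torch

# Toy sizes: 4 queries (3 object + 1 class) and 3 caption tokens.
num_queries, contxt_len = 4, 3
m = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool()
m[:, :num_queries, num_queries:] = True                     # queries may not see caption tokens
m[:, num_queries:, num_queries:] = torch.triu(              # caption tokens are causal
    torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool()
m[:, :num_queries-1, num_queries-1:num_queries] = True      # object queries may not see the class query
m[:, num_queries-1:num_queries, :num_queries-1] = True      # the class query may not see object queries
print(m.int()[0])                                            # 1 marks a blocked position
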
157
+ @classmethod
158
+ def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra):
159
+ ret = {}
160
+
161
+ ret["lang_encoder"] = lang_encoder
162
+ ret["in_channels"] = in_channels
163
+ ret["mask_classification"] = mask_classification
164
+
165
+ enc_cfg = cfg['MODEL']['ENCODER']
166
+ dec_cfg = cfg['MODEL']['DECODER']
167
+
168
+ ret["hidden_dim"] = dec_cfg['HIDDEN_DIM']
169
+ ret["dim_proj"] = cfg['MODEL']['DIM_PROJ']
170
+ ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES']
171
+ ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
172
+
173
+ # Transformer parameters:
174
+ ret["nheads"] = dec_cfg['NHEADS']
175
+ ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD']
176
+
177
+ # NOTE: because the learnable query features require supervision,
178
+ # we subtract 1 from the configured decoder layers to stay consistent with our loss
179
+ # implementation: the number of auxiliary losses always equals the number of
180
+ # decoder layers. With learnable query features, the number of auxiliary
181
+ # losses equals the number of decoder layers plus 1.
182
+ assert dec_cfg['DEC_LAYERS'] >= 1
183
+ ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1
184
+ ret["pre_norm"] = dec_cfg['PRE_NORM']
185
+ ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ']
186
+ ret["mask_dim"] = enc_cfg['MASK_DIM']
187
+
188
+ ret["task_switch"] = extra['task_switch']
189
+ ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50)
190
+
191
+ return ret
192
+
193
+ def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}):
194
+ if task == 'captioning_infer':
195
+ return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra)
196
+ # x is a list of multi-scale feature
197
+ assert len(x) == self.num_feature_levels
198
+ src = []
199
+ pos = []
200
+ size_list = []
201
+
202
+ # disable mask, it does not affect performance
203
+ del mask
204
+ for i in range(self.num_feature_levels):
205
+ size_list.append(x[i].shape[-2:])
206
+ pos.append(self.pe_layer(x[i], None).flatten(2))
207
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
208
+
209
+ # flatten NxCxHxW to HWxNxC
210
+ pos[-1] = pos[-1].permute(2, 0, 1)
211
+ src[-1] = src[-1].permute(2, 0, 1)
212
+
213
+ _, bs, _ = src[0].shape
214
+
215
+ # QxNxC
216
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
217
+ output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
218
+
219
+ predictions_class = []
220
+ predictions_mask = []
221
+ predictions_bbox = []
222
+ predictions_caption = []
223
+ predictions_captioning = []
224
+
225
+ self_tgt_mask = None
226
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
227
+ # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token.
228
+ caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output
229
+ _caping_lang_embed = caping_lang_embed.detach().clone()
230
+ output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token.
231
+ caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
232
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
233
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
234
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
235
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
236
+ grounding_tokens = extra['grounding_tokens']
237
+ _grounding_tokens = grounding_tokens.detach().clone()
238
+ # initialize with negative attention at the beginning.
239
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
240
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
241
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
242
+ self_tgt_mask = pad_tgt_mask
243
+ output = torch.cat((output, output[:-1]), dim=0)
244
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad the language embedding so the sizes stay consistent
245
+ else:
246
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
247
+
248
+ # prediction heads on learnable query features
249
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
250
+ attn_mask = results["attn_mask"]
251
+ predictions_class.append(results["outputs_class"])
252
+ predictions_mask.append(results["outputs_mask"])
253
+ predictions_bbox.append(results["outputs_bbox"])
254
+ predictions_caption.append(results["outputs_caption"])
255
+ predictions_captioning.append(results["outputs_captionting"])
256
+
257
+ for i in range(self.num_layers):
258
+ level_index = i % self.num_feature_levels
259
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
260
+
261
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
262
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
263
+ # attention: cross-attention first
264
+ output, avg_attn = self.transformer_cross_attention_layers[i](
265
+ output, src[level_index],
266
+ memory_mask=attn_mask,
267
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
268
+ pos=pos[level_index], query_pos=query_embed
269
+ )
270
+
271
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
272
+ output = torch.cat((output, _grounding_tokens), dim=0)
273
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
274
+
275
+ output = self.transformer_self_attention_layers[i](
276
+ output, tgt_mask=self_tgt_mask,
277
+ tgt_key_padding_mask=None,
278
+ query_pos=query_embed
279
+ )
280
+
281
+ # FFN
282
+ output = self.transformer_ffn_layers[i](
283
+ output
284
+ )
285
+
286
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']:
287
+ _grounding_tokens = output[-len(_grounding_tokens):]
288
+ output = output[:-len(_grounding_tokens)]
289
+ query_embed = query_embed[:-len(_grounding_tokens)]
290
+
291
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
292
+ attn_mask = results["attn_mask"]
293
+ predictions_class.append(results["outputs_class"])
294
+ predictions_mask.append(results["outputs_mask"])
295
+ predictions_bbox.append(results["outputs_bbox"])
296
+ predictions_caption.append(results["outputs_caption"])
297
+ predictions_captioning.append(results["outputs_captionting"])
298
+
299
+ assert len(predictions_class) == self.num_layers + 1
300
+ if task == 'vlp':
301
+ out = {'pred_captionings': predictions_captioning[-1],
302
+ 'pred_captions': predictions_caption[-1],
303
+ 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]}
304
+ return out
305
+ else:
306
+ out = {
307
+ 'pred_logits': predictions_class[-1],
308
+ 'pred_masks': predictions_mask[-1],
309
+ 'pred_boxes': predictions_bbox[-1],
310
+ 'pred_captions': predictions_caption[-1],
311
+ 'aux_outputs': self._set_aux_loss(
312
+ predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption
313
+ )
314
+ }
315
+ return out
316
+
317
+ def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}):
318
+ # x is a list of multi-scale feature
319
+ assert len(x) == self.num_feature_levels
320
+ src = []
321
+ pos = []
322
+ size_list = []
323
+
324
+ # disable mask, it does not affect performance
325
+ del mask
326
+ for i in range(self.num_feature_levels):
327
+ size_list.append(x[i].shape[-2:])
328
+ pos.append(self.pe_layer(x[i], None).flatten(2))
329
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
330
+
331
+ # flatten NxCxHxW to HWxNxC
332
+ pos[-1] = pos[-1].permute(2, 0, 1)
333
+ src[-1] = src[-1].permute(2, 0, 1)
334
+
335
+ _, bs, _ = src[0].shape
336
+
337
+ # QxNxC
338
+ query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
339
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
340
+ caping_lang_token = extra['start_token'].repeat(bs, 1)
341
+ pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
342
+
343
+ # prepare token embedding for evaluation
344
+ token_embs = self.lang_encoder.lang_encoder.token_embedding.weight
345
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
346
+
347
+ for cap_idx in range(0, self.captioning_step):
348
+ caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1)
349
+ output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token.
350
+ caping_lang_embed += pos_embed_caping
351
+ query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning.
352
+ # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token.
353
+
354
+ # prediction heads on learnable query features
355
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
356
+ attn_mask = results["attn_mask"]
357
+
358
+ for i in range(self.num_layers):
359
+ level_index = i % self.num_feature_levels
360
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
361
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
362
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
363
+
364
+ if extra['captioning_mask'] is not None:
365
+ bs,nq,wh = attn_mask.shape
366
+ assert bs==self.num_heads, "Only support single image referring captioning."
367
+ cap_mask = extra['captioning_mask']
368
+ attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1])
369
+ cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0]
370
+ attn_mask[:,self.num_queries:, cap_mask] = True
371
+ attn_mask = attn_mask.reshape(bs,nq,wh)
372
+
373
+ # attention: cross-attention first
374
+ output, avg_attn = self.transformer_cross_attention_layers[i](
375
+ output, src[level_index],
376
+ memory_mask=attn_mask,
377
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
378
+ pos=pos[level_index], query_pos=query_embed
379
+ )
380
+
381
+ output = self.transformer_self_attention_layers[i](
382
+ output, tgt_mask=self_tgt_mask,
383
+ tgt_key_padding_mask=None,
384
+ query_pos=query_embed
385
+ )
386
+
387
+ # FFN
388
+ output = self.transformer_ffn_layers[i](
389
+ output
390
+ )
391
+
392
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
393
+ attn_mask = results["attn_mask"]
394
+
395
+ pred_captions_gen = results['outputs_captionting']
396
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
397
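+ # Greedy decoding step: project the predicted caption embeddings onto the token-embedding
+ # matrix and take the argmax as the next token, one token per captioning_step iteration.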
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
398
+ caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1]
399
+
400
+ texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False)
401
+ texts_new = []
402
+
403
+ for x in texts:
404
+ x = x.split('<|endoftext|>')[0]
405
+ x = x.replace('<|endoftext|>','')
406
+ x = x.replace('<|startoftext|>','')
407
+ x = x.strip()
408
+ texts_new.append(x)
409
+
410
+ out = {'pred_captionings': caping_lang_token,
411
+ 'pred_texts': texts_new}
412
+ return out
413
+
414
+
415
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'):
416
+ decoder_output = self.decoder_norm(output)
417
+ decoder_output = decoder_output.transpose(0, 1)
418
+
419
+ # extract image captioning token from decoder output.
420
+ if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'):
421
+ outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed
422
+ else:
423
+ outputs_captionting = None
424
+
425
+ # recompute class token output.
426
+ norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7)
427
+ obj_token = norm_decoder_output[:,:self.num_queries-1]
428
+ cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries]
429
+
430
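+ # Recompute the class token as a similarity-weighted average of the object queries:
+ # cosine similarities (both sides are L2-normalized) are softmaxed and used as weights.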
+ sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token.
431
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True)
432
+
433
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']):
434
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1)
435
+ else:
436
+ decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1)
437
+
438
+ # compute class, mask and bbox.
439
+ class_embed = decoder_output @ self.class_embed
440
+ # HACK do not compute similarity if mask is not on
441
+ outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training)))
442
+
443
+ if self.task_switch['mask']:
444
+ mask_embed = self.mask_embed(decoder_output)
445
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
446
+
447
+ # NOTE: prediction is of higher-resolution
448
+ # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
449
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bicubic", align_corners=False, antialias=True)
450
+
451
+ # must use bool type
452
+ # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
453
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
454
+ attn_mask = attn_mask.detach()
455
+
456
+ # NOTE: fill False for cls token (JY)
457
+ attn_mask[:, self.num_queries:self.num_queries+1].fill_(False)
458
+ else:
459
+ outputs_mask = None
460
+ attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool()
461
+
462
+ outputs_bbox = [None for i in range(len(decoder_output))]
463
+ if self.task_switch['bbox']:
464
+ outputs_bbox = self.bbox_embed(decoder_output)
465
+
466
+ outputs_caption = None
467
+ if self.task_switch['caption']:
468
+ outputs_caption = class_embed
469
+
470
+
471
+ results = {
472
+ "outputs_class": outputs_class,
473
+ "outputs_mask": outputs_mask,
474
+ "outputs_bbox": outputs_bbox,
475
+ "attn_mask": attn_mask,
476
+ "outputs_caption": outputs_caption,
477
+ "outputs_captionting": outputs_captionting,
478
+ }
479
+ return results
480
+
481
+ @torch.jit.unused
482
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions):
483
+ # this is a workaround to make torchscript happy, as torchscript
484
+ # doesn't support dictionary with non-homogeneous values, such
485
+ # as a dict having both a Tensor and a list.
486
+ if self.mask_classification:
487
+ return [
488
+ {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d}
489
+ for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1])
490
+ ]
491
+ else:
492
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
493
+
494
+
495
+ @register_decoder
496
+ def get_xdecoder_interface(cfg, in_channels, lang_encoder, mask_classification, extra):
497
+ return XDecoder(cfg, in_channels, lang_encoder, mask_classification, extra)
modeling/language/LangEncoder/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ import os
+
+ from transformers import CLIPTokenizer, CLIPTokenizerFast
2
+ from transformers import AutoTokenizer
3
+
4
+ from .transformer import *
5
+ from .build import *
6
+
7
+
8
+ def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
9
+ model_name = config_encoder['NAME']
10
+
11
+ if not is_lang_encoder(model_name):
12
+ raise ValueError(f'Unknown model: {model_name}')
13
+
14
+ return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
15
+
16
+ def build_tokenizer(config_encoder):
17
+ tokenizer = None
18
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
19
+ if config_encoder['TOKENIZER'] == 'clip':
20
+ pretrained_tokenizer = config_encoder.get(
21
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
22
+ )
23
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
24
+ tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
25
+ elif config_encoder['TOKENIZER'] == 'clip-fast':
26
+ pretrained_tokenizer = config_encoder.get(
27
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
28
+ )
29
+ tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
30
+ elif config_encoder['TOKENIZER'] == 'biomed-clip':
31
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
32
+ else:
33
+ tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
34
+
35
+ return tokenizer
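For orientation, a minimal usage sketch of build_tokenizer (the config dict and the context length of 77 are illustrative assumptions, not values taken from this commit's configs):

    from modeling.language.LangEncoder import build_tokenizer

    text_cfg = {'TOKENIZER': 'biomed-clip'}   # or 'clip' / 'clip-fast' / any HF tokenizer name
    tokenizer = build_tokenizer(text_cfg)
    tokens = tokenizer(['neoplastic cells in breast pathology'],
                       padding='max_length', truncation=True,
                       max_length=77, return_tensors='pt')
    print(tokens['input_ids'].shape)          # e.g. torch.Size([1, 77])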
modeling/language/LangEncoder/build.py ADDED
@@ -0,0 +1,16 @@
1
+ _lang_encoders = {}
2
+
3
+
4
+ def register_lang_encoder(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+
8
+ _lang_encoders[model_name] = fn
9
+
10
+ return fn
11
+
12
+ def lang_encoders(model_name):
13
+ return _lang_encoders[model_name]
14
+
15
+ def is_lang_encoder(model_name):
16
+ return model_name in _lang_encoders
modeling/language/LangEncoder/transformer.py ADDED
@@ -0,0 +1,222 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ import logging
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from timm.models.layers import DropPath, trunc_normal_
12
+
13
+ from .build import register_lang_encoder
14
+ from utilities.distributed import is_main_process
15
+ from utilities.model import register_norm_module
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @register_norm_module
21
+ class LayerNorm(nn.Module):
22
+ def __init__(self, hidden_size, eps=1e-12):
23
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
24
+ """
25
+ super(LayerNorm, self).__init__()
26
+ self.weight = nn.Parameter(torch.ones(hidden_size))
27
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
28
+ self.variance_epsilon = eps
29
+
30
+ def forward(self, x):
31
+ pdtype = x.dtype
32
+ x = x.float()
33
+ u = x.mean(-1, keepdim=True)
34
+ s = (x - u).pow(2).mean(-1, keepdim=True)
35
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
36
+ return self.weight * x.to(pdtype) + self.bias
37
+
38
+
39
+ class QuickGELU(nn.Module):
40
+ def forward(self, x: torch.Tensor):
41
+ return x * torch.sigmoid(1.702 * x)
42
+
43
+
44
+ class ResidualAttentionBlock(nn.Module):
45
+ def __init__(self,
46
+ d_model: int,
47
+ n_head: int,
48
+ attn_mask: torch.Tensor = None,
49
+ drop_path: float = 0.0):
50
+ super().__init__()
51
+
52
+ self.attn = nn.MultiheadAttention(d_model, n_head)
53
+ self.ln_1 = LayerNorm(d_model)
54
+ self.mlp = nn.Sequential(OrderedDict([
55
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
56
+ ("gelu", QuickGELU()),
57
+ ("c_proj", nn.Linear(d_model * 4, d_model))
58
+ ]))
59
+ self.ln_2 = LayerNorm(d_model)
60
+ self.attn_mask = attn_mask
61
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
62
+
63
+ def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
64
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
65
+ if self.attn_mask is not None else None
66
+
67
+
68
+ return self.attn(
69
+ x, x, x,
70
+ key_padding_mask=key_padding_mask,
71
+ need_weights=False,
72
+ attn_mask=self.attn_mask
73
+ )[0]
74
+
75
+ def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
76
+ x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
77
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
78
+ return x
79
+
80
+
81
+ class Transformer(nn.Module):
82
+ def __init__(self,
83
+ context_length: int,
84
+ vocab_size: int,
85
+ width: int,
86
+ layers: int,
87
+ heads: int,
88
+ drop_path: float = 0.0,
89
+ autogressive: bool =True):
90
+ super().__init__()
91
+
92
+ self.token_embedding = nn.Embedding(vocab_size, width)
93
+
94
+ self.context_length = context_length
95
+ self.positional_embedding = nn.Parameter(
96
+ torch.empty(self.context_length, width)
97
+ )
98
+
99
+ self.width = width
100
+ self.layers = layers
101
+ self.autogressive = autogressive
102
+ attn_mask = self.build_attention_mask() if autogressive else None
103
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
104
+ self.resblocks = nn.ModuleList(
105
+ [
106
+ ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
107
+ for i in range(layers)
108
+ ]
109
+ )
110
+
111
+ self.ln_final = LayerNorm(width)
112
+
113
+ trunc_normal_(self.positional_embedding, std=.02)
114
+ # nn.init.normal_(self.token_embedding, std=.02)
115
+ trunc_normal_(self.token_embedding.weight, std=.02)
116
+ self.apply(self._init_weights)
117
+
118
+ @property
119
+ def dim_out(self):
120
+ return self.width
121
+
122
+ def build_attention_mask(self):
123
+ # lazily create causal attention mask, with full attention between the vision tokens
124
+ # pytorch uses additive attention mask; fill with -inf
125
+ mask = torch.empty(self.context_length, self.context_length)
126
+ mask.fill_(float("-inf"))
127
+ mask.triu_(1) # zero out the lower diagonal
128
+ return mask
129
+
130
+ def _init_weights(self, m):
131
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
132
+ if is_main_process():
133
+ logger.info('=> init weight of Linear/Conv2d from trunc norm')
134
+ trunc_normal_(m.weight, std=0.02)
135
+ if m.bias is not None:
136
+ if is_main_process():
137
+ logger.info('=> init bias of Linear/Conv2d to zeros')
138
+ nn.init.constant_(m.bias, 0)
139
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
140
+ nn.init.constant_(m.bias, 0)
141
+
142
+ def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
143
+ if os.path.isfile(pretrained):
144
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
145
+ logging.info(f'=> loading pretrained model {pretrained}')
146
+ model_dict = self.state_dict()
147
+ stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x
148
+ pretrained_dict = {
149
+ stripped_key(k): v for k, v in pretrained_dict.items()
150
+ if stripped_key(k) in model_dict.keys()
151
+ }
152
+ need_init_state_dict = {}
153
+ for k, v in pretrained_dict.items():
154
+ need_init = (
155
+ k.split('.')[0] in pretrained_layers
156
+ or pretrained_layers[0] == '*'
157
+ )
158
+ if need_init:
159
+ if verbose:
160
+ logger.info(f'=> init {k} from {pretrained}')
161
+
162
+ if 'positional_embedding' in k and v.size() != model_dict[k].size():
163
+ positional_embedding_pretrained = v
164
+ positional_embedding_current = model_dict[k]
165
+ L1, nH1 = positional_embedding_pretrained.size()
166
+ L2, nH2 = positional_embedding_current.size()
167
+ if nH1 != nH2:
168
+ logger.info(f"Error in loading {k}, passing")
169
+ else:
170
+ if L1 != L2:
171
+ logger.info(
172
+ '=> load_pretrained: resized variant: {} to {}'
173
+ .format((L1, nH1), (L2, nH2))
174
+ )
175
+
176
+ posemb = positional_embedding_pretrained.float()
177
+ posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
178
+ posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
179
+ posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
180
+ v = posemb_grid
181
+
182
+ need_init_state_dict[k] = v
183
+
184
+ self.load_state_dict(need_init_state_dict, strict=False)
185
+
186
+
187
+ @torch.jit.ignore
188
+ def no_weight_decay(self):
189
+ return {
190
+ 'positional_embedding',
191
+ 'token_embedding',
192
+ }
193
+
194
+ def forward(self, input_ids, attention_mask=None):
195
+ key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None
196
+ # key_padding_mask = (input_ids == 0) if not self.autogressive else None
197
+ x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
198
+ x = x + self.positional_embedding
199
+ x = x.permute(1, 0, 2) # NLD -> LND
200
+ for block in self.resblocks:
201
+ x = block(x, key_padding_mask)
202
+ x = x.permute(1, 0, 2) # LND -> NLD
203
+
204
+ x = self.ln_final(x)
205
+
206
+ return {'last_hidden_state': x}
207
+
208
+
209
+ @register_lang_encoder
210
+ def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
211
+ transformer = Transformer(
212
+ context_length=config_encoder['CONTEXT_LENGTH'],
213
+ vocab_size=tokenizer.vocab_size,
214
+ width=config_encoder['WIDTH'],
215
+ layers=config_encoder['LAYERS'],
216
+ heads=config_encoder['HEADS'],
217
+ autogressive=config_encoder.get('AUTOGRESSIVE', True)
218
+ )
219
+
220
+ if config_encoder.get('LOAD_PRETRAINED', False):
221
+ transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*']))
222
+ return transformer
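The factory above reads a small set of config keys; a hedged example of the expected structure (the values are placeholders, not this commit's actual YAML):

    config_encoder = {
        'CONTEXT_LENGTH': 77,
        'WIDTH': 512,
        'LAYERS': 12,
        'HEADS': 8,
        'AUTOGRESSIVE': True,      # key spelling matches the code above
        'LOAD_PRETRAINED': False,
    }
    # transformer = lang_encoder(config_encoder, tokenizer, verbose=True)
    # out = transformer(input_ids)   # {'last_hidden_state': [B, CONTEXT_LENGTH, WIDTH]}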
modeling/language/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .vlpencoder import *
2
+ from .build import *
3
+
4
+ def build_language_encoder(config, **kwargs):
5
+ model_name = config['MODEL']['TEXT']['ARCH']
6
+
7
+ if not is_model(model_name):
8
+ raise ValueError(f'Unknown model: {model_name}')
9
+
10
+ return model_entrypoints(model_name)(config, **kwargs)
modeling/language/build.py ADDED
@@ -0,0 +1,14 @@
1
+ _model_entrypoints = {}
2
+
3
+
4
+ def register_model(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+ _model_entrypoints[model_name] = fn
8
+ return fn
9
+
10
+ def model_entrypoints(model_name):
11
+ return _model_entrypoints[model_name]
12
+
13
+ def is_model(model_name):
14
+ return model_name in _model_entrypoints
modeling/language/loss.py ADDED
@@ -0,0 +1,232 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import pickle
9
+ from distutils import log
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.distributed as dist
14
+
15
+ from einops import rearrange, repeat
16
+ from timm.loss import SoftTargetCrossEntropy
17
+
18
+ soft_cross_entropy = SoftTargetCrossEntropy()
19
+
20
+ def is_dist_initialized():
21
+ return torch.distributed.is_initialized()
22
+
23
+ def get_world_size():
24
+ if is_dist_initialized():
25
+ return torch.distributed.get_world_size()
26
+ return 1
27
+
28
+ def get_rank():
29
+ if is_dist_initialized():
30
+ return dist.get_rank()
31
+ return 0
32
+
33
+ def all_gather_grad(x):
34
+ if get_world_size() > 1:
35
+ all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
36
+ torch.distributed.all_gather(all_x, x)
37
+ all_x[torch.distributed.get_rank()] = x
38
+ x = torch.cat(all_x, dim=0)
39
+ return x
40
+
41
+ def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
42
+ """
43
+ Args:
44
+ image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
45
+ text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256
46
+
47
+ Returns:
48
+ """
49
+ # [B, L1, C], L1 = 1
50
+ # image_feat = F.normalize(image_feat, dim=-1)
51
+ # [B, L2, C]
52
+ # text_feat = F.normalize(text_feat, dim=-1)
53
+ # HACK: normalize outside
54
+
55
+ # [B, L1, L2]
56
+ dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
57
+ # [B, L2, L1]
58
+ dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
59
+
60
+ batch = image_feat.shape[0]
61
+ img_len = image_feat.shape[1]
62
+ text_len = text_feat.shape[1]
63
+ # [B, L1, L2]
64
+ pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
65
+ # [B, L2, L1]
66
+ pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
67
+
68
+ image_x = rearrange(image_feat, 'b l c -> (b l) c')
69
+ text_x = rearrange(text_feat, 'b l c -> (b l) c')
70
+
71
+ logits_per_img = image_x @ all_gather_grad(text_x).t()
72
+ logits_per_text = text_x @ all_gather_grad(image_x).t()
73
+
74
+ # get label globally
75
+ # [B, L1, B, L2, W]
76
+ labels_per_img = F.one_hot(
77
+ torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
78
+ num_classes=get_world_size()).to(image_x.dtype)
79
+ labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
80
+ torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
81
+ # [BxL1, WxBxL2]
82
+ labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
83
+ # [B, L2, B, L1, W]
84
+ labels_per_text = F.one_hot(
85
+ torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
86
+ num_classes=get_world_size()).to(text_x.dtype)
87
+ labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
88
+ torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
89
+ # [BxL2, WxBxL1]
90
+ labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
91
+
92
+ logit_scale = temperature.exp().clamp(max=100)
93
+
94
+ loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
95
+ loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
96
+
97
+ loss = 0.5 * (loss_img + loss_text)
98
+ return loss
99
+
100
+ def vl_contrastive_loss(image_feat, text_feat, temperature=1):
101
+ # if image_id or text_id is None, it should be None across all GPUs
102
+ # image_feat = F.normalize(image_feat, dim=1)
103
+ # text_feat = F.normalize(text_feat, dim=1)
104
+ # handle normalization outside
105
+
106
+ # add the following 4 lines
107
+ image_feat = all_gather_grad(image_feat)
108
+ text_feat = all_gather_grad(text_feat)
109
+
110
+ logits = torch.matmul(image_feat, text_feat.t())
111
+ logit_scale = temperature.exp().clamp(max=100)
112
+
113
+ gt = torch.arange(logits.shape[0], device=logits.device)
114
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
115
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
116
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
117
+
118
+
119
+ def all_gather_pickle(data, device):
120
+ """
121
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
122
+ Args:
123
+ data: any picklable object
124
+ Returns:
125
+ list[data]: list of data gathered from each rank
126
+ """
127
+ world_size = get_world_size()
128
+ if world_size == 1:
129
+ return [data]
130
+
131
+ # serialized to a Tensor
132
+ buffer = pickle.dumps(data)
133
+ storage = torch.ByteStorage.from_buffer(buffer)
134
+ tensor = torch.ByteTensor(storage).to(device)
135
+
136
+ # obtain Tensor size of each rank
137
+ local_size = torch.LongTensor([tensor.numel()]).cuda()
138
+ size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
139
+ dist.all_gather(size_list, local_size)
140
+ size_list = [int(size.item()) for size in size_list]
141
+ max_size = max(size_list)
142
+
143
+ # receiving Tensor from all ranks
144
+ # we pad the tensor because torch all_gather does not support
145
+ # gathering tensors of different shapes
146
+ tensor_list = []
147
+ for _ in size_list:
148
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())
149
+ if local_size != max_size:
150
+ padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
151
+ tensor = torch.cat((tensor, padding), dim=0)
152
+ dist.all_gather(tensor_list, tensor)
153
+
154
+ data_list = []
155
+ for size, tensor in zip(size_list, tensor_list):
156
+ buffer = tensor.cpu().numpy().tobytes()[:size]
157
+ data_list.append(pickle.loads(buffer))
158
+
159
+ return data_list
160
+
161
+ def all_gather_arbitary_tensor(tensor):
162
+ if get_world_size() > 1:
163
+ device = tensor.device
164
+ tensor_batch = all_gather_pickle(tensor.cpu(), device)
165
+ tensor_batch = [x.to(device) for x in tensor_batch]
166
+ tensor_batch[torch.distributed.get_rank()] = tensor
167
+ tensor_batch = torch.cat(tensor_batch, dim=0)
168
+ else:
169
+ tensor_batch = tensor
170
+ return tensor_batch
171
+
172
+ def ql_contrastive_loss(image_feat, text_feat, temperature=1):
173
+ # add the following 4 lines
174
+ image_feat = all_gather_arbitary_tensor(image_feat)
175
+ text_feat = all_gather_arbitary_tensor(text_feat)
176
+
177
+ logits = torch.matmul(image_feat, text_feat.t())
178
+ logit_scale = temperature.exp().clamp(max=100)
179
+
180
+ gt = torch.arange(logits.shape[0], device=logits.device)
181
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
182
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
183
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
184
+
185
+ def vl_similarity(image_feat, text_feat, temperature=1):
186
+ # Only support single GPU for now.
187
+ logits = torch.matmul(image_feat, text_feat.t())
188
+ logits = temperature.exp().clamp(max=100) * logits
189
+ return logits
190
+
191
+ def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
192
+ # add the following 4 lines
193
+ image_feat = all_gather_arbitary_tensor(image_feat)
194
+ text_feat = all_gather_arbitary_tensor(text_feat)
195
+
196
+ text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
197
+ text_hash_all = torch.cat(text_hash_batch)
198
+
199
+ text_hash_all_unique = torch.unique(text_hash_all).tolist()
200
+ gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
201
+ text_hash_all = text_hash_all.tolist()
202
+ text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
203
+
204
+ for idx, txt in enumerate(text_hash_all):
205
+ gt[idx][text_hash_all_unique.index(txt)] = 1
206
+
207
+ logits = torch.matmul(image_feat, text_feat_unique.t())
208
+ logits = logits*temperature.exp().clamp(max=100)
209
+
210
+ loss_img = soft_cross_entropy(logits, gt)
211
+ loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
212
+
213
+ loss = 0.7 * loss_img + 0.3 * loss_text
214
+ return loss
215
+
216
+ def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
217
+ # add the following 4 lines
218
+ image_feat = all_gather_grad(image_feat_inp.contiguous())
219
+ text_feat = all_gather_grad(text_feat_inp.contiguous())
220
+
221
+ image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
222
+ text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
223
+
224
+ temperature = lang_enc.logit_scale
225
+ logits = torch.matmul(image_feat, text_feat.t())
226
+ logit_scale = temperature.exp().clamp(max=100)
227
+
228
+ gt = torch.arange(logits.shape[0], device=logits.device)
229
+ loss1 = F.cross_entropy(logit_scale * logits, gt)
230
+ loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
231
+
232
+ return (loss1 + loss2) / 2 # scale it up by the number of GPUs
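As a single-process sanity check of the similarity helper above, a minimal sketch (shapes and the temperature value are illustrative; temperature is assumed to be a tensor holding a log-scale, as with lang_enc.logit_scale):

    import torch
    import torch.nn.functional as F
    from modeling.language.loss import vl_similarity

    image_feat = F.normalize(torch.randn(4, 512), dim=-1)   # e.g. 4 mask/region embeddings
    text_feat = F.normalize(torch.randn(7, 512), dim=-1)    # e.g. 7 class-text embeddings
    logits = vl_similarity(image_feat, text_feat, temperature=torch.tensor(2.0))
    print(logits.shape)   # torch.Size([4, 7]); row-wise argmax picks the best-matching text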
modeling/language/misc.py ADDED
@@ -0,0 +1,66 @@
1
+ import random
2
+
3
+ import torch
4
+ import nltk
5
+ import numpy as np
6
+
7
+ from utilities.constants import IMAGENET_DEFAULT_TEMPLATES
8
+
9
+ nltk.download('punkt', quiet=True)
10
+ nltk.download('averaged_perceptron_tagger', quiet=True)
11
+
12
+ def get_tag(tokenized, tags):
13
+ if not isinstance(tags, (list, tuple)):
14
+ tags = [tags]
15
+ ret = []
16
+ for (word, pos) in nltk.pos_tag(tokenized):
17
+ for tag in tags:
18
+ if pos == tag:
19
+ ret.append(word)
20
+ return ret
21
+
22
+ def get_noun_phrase(tokenized):
23
+ # Taken from Su Nam Kim Paper...
24
+ grammar = r"""
25
+ NBAR:
26
+ {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns
27
+
28
+ NP:
29
+ {<NBAR>}
30
+ {<NBAR><IN><NBAR>} # Above, connected with in/of/etc...
31
+ """
32
+ chunker = nltk.RegexpParser(grammar)
33
+
34
+ chunked = chunker.parse(nltk.pos_tag(tokenized))
35
+ continuous_chunk = []
36
+ current_chunk = []
37
+
38
+ for subtree in chunked:
39
+ if isinstance(subtree, nltk.Tree):
40
+ current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
41
+ elif current_chunk:
42
+ named_entity = ' '.join(current_chunk)
43
+ if named_entity not in continuous_chunk:
44
+ continuous_chunk.append(named_entity)
45
+ current_chunk = []
46
+ else:
47
+ continue
48
+
49
+ return continuous_chunk
50
+
51
+ def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
52
+ tokenized = nltk.word_tokenize(text)
53
+
54
+ if random.random() >= phrase_prob:
55
+ nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
56
+ else:
57
+ nouns = get_noun_phrase(tokenized)
58
+
59
+
60
+ prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
61
+
62
+ if append_text:
63
+ prompt_texts += [text]
64
+ nouns += [text]
65
+
66
+ return prompt_texts, nouns
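A quick illustration of the prompting helper above (the exact outputs vary because the template is sampled at random from IMAGENET_DEFAULT_TEMPLATES and NLTK tagging is heuristic; the input sentence is an assumed example):

    from modeling.language.misc import text_noun_with_prompt_all

    prompts, nouns = text_noun_with_prompt_all('a CT scan of the liver with a small lesion', phrase_prob=0.0)
    # nouns   -> e.g. ['scan', 'liver', 'lesion', 'a CT scan of the liver with a small lesion']
    # prompts -> one randomly chosen template filled with each noun, plus the raw sentence appended last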
modeling/language/vlpencoder.py ADDED
@@ -0,0 +1,206 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from timm.models.layers import trunc_normal_
13
+
14
+ from .build import register_model
15
+ from ..utils import configurable
16
+ from .LangEncoder import build_tokenizer, build_lang_encoder
17
+ from utilities.prompt_engineering import prompt_engineering, get_prompt_templates
18
+
19
+ from transformers import AutoTokenizer, AutoModel
20
+
21
+ class LanguageEncoder(nn.Module):
22
+
23
+ @configurable
24
+ def __init__(
25
+ self,
26
+ tokenizer,
27
+ tokenizer_type,
28
+ lang_encoder,
29
+ lang_projection,
30
+ max_token_num,
31
+ queue_operator,
32
+ ):
33
+ super().__init__()
34
+ # seg
35
+ self.tokenizer = tokenizer
36
+ self.tokenizer_type = tokenizer_type
37
+ self.lang_encoder = lang_encoder
38
+ self.lang_proj = lang_projection
39
+ self.max_token_num = max_token_num
40
+ self.logit_scale = nn.Parameter(torch.ones([]))
41
+
42
+ # captioning & retrieval
43
+ for key, value in queue_operator.items():
44
+ self.register_buffer(key, value)
45
+
46
+ self.biomed_encoder = AutoModel.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
47
+
48
+ @classmethod
49
+ def from_config(cls, cfg):
50
+ # build up text encoder for seg
51
+ tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
52
+ tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
53
+ lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
54
+ max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
55
+
56
+ dim_lang = cfg['MODEL']['TEXT']['WIDTH']
57
+ dim_projection = cfg['MODEL']['DIM_PROJ']
58
+ lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
59
+ trunc_normal_(lang_projection, std=.02)
60
+
61
+ # tested not working better
62
+ queue_operator = {}
63
+
64
+ return {
65
+ "tokenizer": tokenizer,
66
+ "tokenizer_type": tokenizer_type,
67
+ "lang_encoder": lang_encoder,
68
+ "lang_projection": lang_projection,
69
+ "max_token_num": max_token_num,
70
+ "queue_operator": queue_operator,
71
+ }
72
+
73
+ def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True, store_buffer=None):
74
+ if not is_eval:
75
+ if prompt:
76
+ # randomly sample one template
77
+ arbitary_concepts = [
78
+ prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
79
+ for label in range(len(class_names))
80
+ ]
81
+ if add_bgd:
82
+ arbitary_concepts.append("A background in coco.")
83
+ else:
84
+ arbitary_concepts = class_names
85
+
86
+ input_ids = []
87
+ attention_masks = []
88
+ for txt in arbitary_concepts:
89
+ tokens = self.tokenizer(
90
+ txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
91
+ )
92
+ tokens['input_ids'].squeeze_()
93
+ tokens['attention_mask'].squeeze_()
94
+
95
+ input_ids.append(tokens['input_ids'])
96
+ attention_masks.append(tokens['attention_mask'])
97
+
98
+ arbitary_tokens = torch.stack(input_ids)
99
+ arbitary_attention_masks = torch.stack(attention_masks)
100
+
101
+ text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
102
+ setattr(self, '{}_text_embeddings'.format(name), text_emb)
103
+ else:
104
+ with torch.no_grad():
105
+ def extract_mean_emb(txts):
106
+ tokens = self.tokenizer(
107
+ txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
108
+ )
109
+ clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
110
+ clss_embedding = clss_embedding.mean(dim=0)
111
+ clss_embedding /= clss_embedding.norm()
112
+ return clss_embedding
113
+
114
+ templates = get_prompt_templates()
115
+ clss_embeddings = []
116
+ if prompt:
117
+ for clss in class_names:
118
+ txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
119
+ clss_embeddings.append(extract_mean_emb(txts))
120
+ else:
121
+ for clss in class_names:
122
+ clss_embeddings.append(extract_mean_emb([clss]))
123
+
124
+ if add_bgd:
125
+ txts = ["A background in coco."]
126
+ clss_embeddings.append(extract_mean_emb(txts))
127
+
128
+ text_emb = torch.stack(clss_embeddings, dim=0)
129
+ setattr(self, '{}_text_embeddings'.format(name), text_emb)
130
+
131
+ def reset_text_embeddings(self, name='default'):
132
+ pass
133
+
134
+ def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
135
+ if not token:
136
+ tokens = self.tokenizer(
137
+ txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
138
+ )
139
+ tokens = {key: value.cuda() for key, value in tokens.items()}
140
+ else:
141
+ tokens = txts
142
+ token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
143
+ ret = {"tokens": tokens,
144
+ "token_emb": token_emb,
145
+ "class_emb": class_emb,}
146
+ setattr(self, '{}_token_embeddings'.format(name), ret)
147
+ return ret
148
+
149
+ def forward_language(self, texts, norm=True):
150
+ if self.tokenizer_type == 'biomed-clip':
151
+ with torch.no_grad(): # Disable gradient calculation
152
+ outputs = self.biomed_encoder(*texts)
153
+ # Extract the last hidden state
154
+ x = outputs['last_hidden_state']
155
+ x = x[:, 0] # Get the [CLS] token's embeddings for all examples
156
+ else:
157
+ x = self.lang_encoder(*texts)
158
+ x = x['last_hidden_state']
159
+
160
+ if self.tokenizer_type == 'clip':
161
+ x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
162
+ else:
163
+ x = x[:, 0]
164
+
165
+ x = x @ self.lang_proj
166
+ if norm:
167
+ x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
168
+ return x
169
+
170
+ def forward_language_token(self, texts, norm=False):
171
+ if self.tokenizer_type == 'biomed-clip':
172
+ with torch.no_grad(): # Disable gradient calculation
173
+ outputs = self.biomed_encoder(*texts)
174
+ # Extract the last hidden state
175
+ token_x = outputs['last_hidden_state']
176
+ class_x = token_x[:, 0] # Get the [CLS] token's embeddings for all examples
177
+ else:
178
+ x = self.lang_encoder(*texts)
179
+ token_x = x['last_hidden_state']
180
+
181
+ if self.tokenizer_type == 'clip':
182
+ class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
183
+ else:
184
+ class_x = token_x[:, 0]
185
+
186
+ class_x = class_x @ self.lang_proj
187
+ token_x = token_x @ self.lang_proj
188
+
189
+ if norm:
190
+ class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
191
+ token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
192
+
193
+ return token_x, class_x
194
+
195
+ def compute_similarity(self, v_emb, name='default', fake=False):
196
+ if fake:
197
+ return None
198
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
199
+ t_emb = getattr(self, '{}_text_embeddings'.format(name))
200
+ output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
201
+ return output
202
+
203
+
204
+ @register_model
205
+ def get_language_model(cfg, **kwargs):
206
+ return LanguageEncoder(cfg)
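To make the buffer naming and similarity path above concrete, a hedged sketch of typical use (the class list, DIM_PROJ = 512 and the query count are assumptions; a CUDA device is required because get_text_embeddings moves tokens with .cuda()):

    import torch
    from modeling.language import build_language_encoder

    encoder = build_language_encoder(cfg).cuda().eval()   # cfg as consumed by from_config above

    # registers encoder.default_text_embeddings with shape [num_classes, DIM_PROJ]
    encoder.get_text_embeddings(['tumor', 'liver', 'kidney'], name='default', is_eval=True)

    v_emb = torch.randn(1, 101, 512).cuda()               # e.g. per-query mask embeddings
    logits = encoder.compute_similarity(v_emb, name='default')   # shape [1, 101, 3]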
modeling/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .point_features import *
2
+ from .position_encoding import *
3
+ from .postprocessing import *
4
+ from .attention import *
5
+ from .criterion import *
6
+ from .matcher import *
modeling/modules/attention.py ADDED
@@ -0,0 +1,487 @@
1
+ import warnings
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8
+ from torch.nn.parameter import Parameter
9
+ from torch.overrides import has_torch_function, handle_torch_function
10
+ from torch.nn.functional import pad, linear, softmax, dropout
11
+
12
+
13
+ def multi_head_attention_forward(
14
+ query: Tensor,
15
+ key: Tensor,
16
+ value: Tensor,
17
+ embed_dim_to_check: int,
18
+ num_heads: int,
19
+ in_proj_weight: Tensor,
20
+ in_proj_bias: Tensor,
21
+ bias_k: Optional[Tensor],
22
+ bias_v: Optional[Tensor],
23
+ add_zero_attn: bool,
24
+ dropout_p: float,
25
+ out_proj_weight: Tensor,
26
+ out_proj_bias: Tensor,
27
+ training: bool = True,
28
+ key_padding_mask: Optional[Tensor] = None,
29
+ need_weights: bool = True,
30
+ attn_mask: Optional[Tensor] = None,
31
+ use_separate_proj_weight: bool = False,
32
+ q_proj_weight: Optional[Tensor] = None,
33
+ k_proj_weight: Optional[Tensor] = None,
34
+ v_proj_weight: Optional[Tensor] = None,
35
+ static_k: Optional[Tensor] = None,
36
+ static_v: Optional[Tensor] = None,
37
+ ) -> Tuple[Tensor, Optional[Tensor]]:
38
+ r"""
39
+ Args:
40
+ query, key, value: map a query and a set of key-value pairs to an output.
41
+ See "Attention Is All You Need" for more details.
42
+ embed_dim_to_check: total dimension of the model.
43
+ num_heads: parallel attention heads.
44
+ in_proj_weight, in_proj_bias: input projection weight and bias.
45
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
46
+ add_zero_attn: add a new batch of zeros to the key and
47
+ value sequences at dim=1.
48
+ dropout_p: probability of an element to be zeroed.
49
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
50
+ training: apply dropout if is ``True``.
51
+ key_padding_mask: if provided, specified padding elements in the key will
52
+ be ignored by the attention. This is a binary mask. When the value is True,
53
+ the corresponding value on the attention layer will be filled with -inf.
54
+ need_weights: output attn_output_weights.
55
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
56
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
57
+ use_separate_proj_weight: the function accepts the proj. weights for query, key,
58
+ and value in different forms. If false, in_proj_weight will be used, which is
59
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
60
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
61
+ static_k, static_v: static key and value used for attention operators.
62
+
63
+
64
+ Shape:
65
+ Inputs:
66
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
67
+ the embedding dimension.
68
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
69
+ the embedding dimension.
70
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
71
+ the embedding dimension.
72
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
73
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
74
+ will be unchanged. If a BoolTensor is provided, the positions with the
75
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
76
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
77
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
78
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
79
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
80
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
81
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
82
+ is provided, it will be added to the attention weight.
83
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
86
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
87
+
88
+ Outputs:
89
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
90
+ E is the embedding dimension.
91
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
92
+ L is the target sequence length, S is the source sequence length.
93
+ """
94
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
95
+ if has_torch_function(tens_ops):
96
+ return handle_torch_function(
97
+ multi_head_attention_forward,
98
+ tens_ops,
99
+ query,
100
+ key,
101
+ value,
102
+ embed_dim_to_check,
103
+ num_heads,
104
+ in_proj_weight,
105
+ in_proj_bias,
106
+ bias_k,
107
+ bias_v,
108
+ add_zero_attn,
109
+ dropout_p,
110
+ out_proj_weight,
111
+ out_proj_bias,
112
+ training=training,
113
+ key_padding_mask=key_padding_mask,
114
+ need_weights=need_weights,
115
+ attn_mask=attn_mask,
116
+ use_separate_proj_weight=use_separate_proj_weight,
117
+ q_proj_weight=q_proj_weight,
118
+ k_proj_weight=k_proj_weight,
119
+ v_proj_weight=v_proj_weight,
120
+ static_k=static_k,
121
+ static_v=static_v,
122
+ )
123
+ tgt_len, bsz, embed_dim = query.size()
124
+ assert embed_dim == embed_dim_to_check
125
+ # allow MHA to have different sizes for the feature dimension
126
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
127
+
128
+ head_dim = embed_dim // num_heads
129
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
130
+ scaling = float(head_dim) ** -0.5
131
+
132
+ if not use_separate_proj_weight:
133
+ if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
134
+ # self-attention
135
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
136
+
137
+ elif key is value or torch.equal(key, value):
138
+ # encoder-decoder attention
139
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
140
+ _b = in_proj_bias
141
+ _start = 0
142
+ _end = embed_dim
143
+ _w = in_proj_weight[_start:_end, :]
144
+ if _b is not None:
145
+ _b = _b[_start:_end]
146
+ q = linear(query, _w, _b)
147
+
148
+ if key is None:
149
+ assert value is None
150
+ k = None
151
+ v = None
152
+ else:
153
+
154
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
155
+ _b = in_proj_bias
156
+ _start = embed_dim
157
+ _end = None
158
+ _w = in_proj_weight[_start:, :]
159
+ if _b is not None:
160
+ _b = _b[_start:]
161
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
162
+
163
+ else:
164
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
165
+ _b = in_proj_bias
166
+ _start = 0
167
+ _end = embed_dim
168
+ _w = in_proj_weight[_start:_end, :]
169
+ if _b is not None:
170
+ _b = _b[_start:_end]
171
+ q = linear(query, _w, _b)
172
+
173
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
174
+ _b = in_proj_bias
175
+ _start = embed_dim
176
+ _end = embed_dim * 2
177
+ _w = in_proj_weight[_start:_end, :]
178
+ if _b is not None:
179
+ _b = _b[_start:_end]
180
+ k = linear(key, _w, _b)
181
+
182
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
183
+ _b = in_proj_bias
184
+ _start = embed_dim * 2
185
+ _end = None
186
+ _w = in_proj_weight[_start:, :]
187
+ if _b is not None:
188
+ _b = _b[_start:]
189
+ v = linear(value, _w, _b)
190
+ else:
191
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
192
+ len1, len2 = q_proj_weight_non_opt.size()
193
+ assert len1 == embed_dim and len2 == query.size(-1)
194
+
195
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
196
+ len1, len2 = k_proj_weight_non_opt.size()
197
+ assert len1 == embed_dim and len2 == key.size(-1)
198
+
199
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
200
+ len1, len2 = v_proj_weight_non_opt.size()
201
+ assert len1 == embed_dim and len2 == value.size(-1)
202
+
203
+ if in_proj_bias is not None:
204
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
205
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
206
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
207
+ else:
208
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
209
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
210
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
211
+ q = q * scaling
212
+
213
+ if attn_mask is not None:
214
+ assert (
215
+ attn_mask.dtype == torch.float32
216
+ or attn_mask.dtype == torch.float64
217
+ or attn_mask.dtype == torch.float16
218
+ or attn_mask.dtype == torch.uint8
219
+ or attn_mask.dtype == torch.bool
220
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
221
+ if attn_mask.dtype == torch.uint8:
222
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
223
+ attn_mask = attn_mask.to(torch.bool)
224
+
225
+ if attn_mask.dim() == 2:
226
+ attn_mask = attn_mask.unsqueeze(0)
227
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
228
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
229
+ elif attn_mask.dim() == 3:
230
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
231
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
232
+ else:
233
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
234
+ # attn_mask's dim is 3 now.
235
+
236
+ # convert ByteTensor key_padding_mask to bool
237
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
238
+ warnings.warn(
239
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
240
+ )
241
+ key_padding_mask = key_padding_mask.to(torch.bool)
242
+
243
+ if bias_k is not None and bias_v is not None:
244
+ if static_k is None and static_v is None:
245
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
246
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
247
+ if attn_mask is not None:
248
+ attn_mask = pad(attn_mask, (0, 1))
249
+ if key_padding_mask is not None:
250
+ key_padding_mask = pad(key_padding_mask, (0, 1))
251
+ else:
252
+ assert static_k is None, "bias cannot be added to static key."
253
+ assert static_v is None, "bias cannot be added to static value."
254
+ else:
255
+ assert bias_k is None
256
+ assert bias_v is None
257
+
258
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
259
+ if k is not None:
260
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
261
+ if v is not None:
262
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
263
+
264
+ if static_k is not None:
265
+ assert static_k.size(0) == bsz * num_heads
266
+ assert static_k.size(2) == head_dim
267
+ k = static_k
268
+
269
+ if static_v is not None:
270
+ assert static_v.size(0) == bsz * num_heads
271
+ assert static_v.size(2) == head_dim
272
+ v = static_v
273
+
274
+ src_len = k.size(1)
275
+
276
+ if key_padding_mask is not None:
277
+ # assert key_padding_mask.size(0) == bsz
278
+ assert key_padding_mask.size(1) == src_len
279
+
280
+ if add_zero_attn:
281
+ src_len += 1
282
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
283
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
284
+ if attn_mask is not None:
285
+ attn_mask = pad(attn_mask, (0, 1))
286
+ if key_padding_mask is not None:
287
+ key_padding_mask = pad(key_padding_mask, (0, 1))
288
+
289
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
290
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
291
+
292
+ if attn_mask is not None:
293
+ if attn_mask.dtype == torch.bool:
294
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
295
+ else:
296
+ attn_output_weights += attn_mask
297
+
298
+ if key_padding_mask is not None:
299
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
300
+ attn_output_weights = attn_output_weights.masked_fill(
301
+ key_padding_mask.unsqueeze(1),
302
+ float("-inf"),
303
+ )
304
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
305
+
306
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
307
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
308
+
309
+ attn_output = torch.bmm(attn_output_weights, v)
310
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
311
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
312
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
313
+
314
+ if need_weights:
315
+ # average attention weights over heads
316
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
317
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
318
+ else:
319
+ return attn_output, None
320
+
321
+
322
+ # This class exists solely for Transformer; it has an annotation stating
323
+ # that bias is never None, which appeases TorchScript
324
+ class _LinearWithBias(nn.Linear):
325
+ bias: Tensor # type: ignore
326
+
327
+ def __init__(self, in_features: int, out_features: int) -> None:
328
+ super().__init__(in_features, out_features, bias=True) # type: ignore
329
+
330
+
331
+ class MultiheadAttention(nn.Module):
332
+ r"""Allows the model to jointly attend to information
333
+ from different representation subspaces.
334
+ See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_
335
+
336
+ .. math::
337
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
338
+
339
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
340
+
341
+ Args:
342
+ embed_dim: total dimension of the model.
343
+ num_heads: parallel attention heads.
344
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
345
+ bias: add bias as module parameter. Default: True.
346
+ add_bias_kv: add bias to the key and value sequences at dim=0.
347
+ add_zero_attn: add a new batch of zeros to the key and
348
+ value sequences at dim=1.
349
+ kdim: total number of features in key. Default: None.
350
+ vdim: total number of features in value. Default: None.
351
+
352
+ Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
353
+ to :attr:`embed_dim` such that query, key, and value have the same
354
+ number of features.
355
+
356
+ Examples::
357
+
358
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
359
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
360
+ """
361
+ bias_k: Optional[torch.Tensor]
362
+ bias_v: Optional[torch.Tensor]
363
+
364
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
365
+ super(MultiheadAttention, self).__init__()
366
+ self.embed_dim = embed_dim
367
+ self.kdim = kdim if kdim is not None else embed_dim
368
+ self.vdim = vdim if vdim is not None else embed_dim
369
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
370
+
371
+ self.num_heads = num_heads
372
+ self.dropout = dropout
373
+ self.head_dim = embed_dim // num_heads
374
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
375
+
376
+ if self._qkv_same_embed_dim is False:
377
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
378
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
379
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
380
+ self.register_parameter('in_proj_weight', None)
381
+ else:
382
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
383
+ self.register_parameter('q_proj_weight', None)
384
+ self.register_parameter('k_proj_weight', None)
385
+ self.register_parameter('v_proj_weight', None)
386
+
387
+ if bias:
388
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
389
+ else:
390
+ self.register_parameter('in_proj_bias', None)
391
+ self.out_proj = _LinearWithBias(embed_dim, embed_dim)
392
+
393
+ if add_bias_kv:
394
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
395
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
396
+ else:
397
+ self.bias_k = self.bias_v = None
398
+
399
+ self.add_zero_attn = add_zero_attn
400
+
401
+ self._reset_parameters()
402
+
403
+ def _reset_parameters(self):
404
+ if self._qkv_same_embed_dim:
405
+ xavier_uniform_(self.in_proj_weight)
406
+ else:
407
+ xavier_uniform_(self.q_proj_weight)
408
+ xavier_uniform_(self.k_proj_weight)
409
+ xavier_uniform_(self.v_proj_weight)
410
+
411
+ if self.in_proj_bias is not None:
412
+ constant_(self.in_proj_bias, 0.)
413
+ constant_(self.out_proj.bias, 0.)
414
+ if self.bias_k is not None:
415
+ xavier_normal_(self.bias_k)
416
+ if self.bias_v is not None:
417
+ xavier_normal_(self.bias_v)
418
+
419
+ def __setstate__(self, state):
420
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
421
+ if '_qkv_same_embed_dim' not in state:
422
+ state['_qkv_same_embed_dim'] = True
423
+
424
+ super(MultiheadAttention, self).__setstate__(state)
425
+
426
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
427
+ need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
428
+ r"""
429
+ Args:
430
+ query, key, value: map a query and a set of key-value pairs to an output.
431
+ See "Attention Is All You Need" for more details.
432
+ key_padding_mask: if provided, specified padding elements in the key will
433
+ be ignored by the attention. When given a binary mask and a value is True,
434
+ the corresponding value on the attention layer will be ignored. When given
435
+ a byte mask and a value is non-zero, the corresponding value on the attention
436
+ layer will be ignored
437
+ need_weights: output attn_output_weights.
438
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
439
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
440
+
441
+ Shapes for inputs:
442
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
443
+ the embedding dimension.
444
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
445
+ the embedding dimension.
446
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
447
+ the embedding dimension.
448
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
449
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero
450
+ positions will be unchanged. If a BoolTensor is provided, the positions with the
451
+ value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
452
+ - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
453
+ source sequence length.
454
+
455
+ If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
456
+ length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
457
+ the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
458
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
459
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
460
+ is provided, it will be added to the attention weight.
461
+
462
+ Shapes for outputs:
463
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
464
+ E is the embedding dimension.
465
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
466
+ L is the target sequence length, S is the source sequence length.
467
+ """
468
+ if not self._qkv_same_embed_dim:
469
+ return multi_head_attention_forward(
470
+ query, key, value, self.embed_dim, self.num_heads,
471
+ self.in_proj_weight, self.in_proj_bias,
472
+ self.bias_k, self.bias_v, self.add_zero_attn,
473
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
474
+ training=self.training,
475
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
476
+ attn_mask=attn_mask, use_separate_proj_weight=True,
477
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
478
+ v_proj_weight=self.v_proj_weight)
479
+ else:
480
+ return multi_head_attention_forward(
481
+ query, key, value, self.embed_dim, self.num_heads,
482
+ self.in_proj_weight, self.in_proj_bias,
483
+ self.bias_k, self.bias_v, self.add_zero_attn,
484
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
485
+ training=self.training,
486
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
487
+ attn_mask=attn_mask)
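
Note: the MultiheadAttention class above reproduces the torch.nn.MultiheadAttention interface with a sequence-first (L, N, E) layout. A minimal usage sketch follows; it is not part of the committed diff, the import path is assumed from this commit's file layout, and the sizes are illustrative only:

    import torch
    from modeling.modules.attention import MultiheadAttention  # assumed import path

    L, S, N, E, H = 6, 10, 2, 64, 8            # target len, source len, batch, embed dim, heads
    mha = MultiheadAttention(embed_dim=E, num_heads=H, dropout=0.1)

    query = torch.randn(L, N, E)
    key = torch.randn(S, N, E)
    value = torch.randn(S, N, E)
    key_padding_mask = torch.zeros(N, S, dtype=torch.bool)   # True would mark padded key positions

    attn_output, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)
    assert attn_output.shape == (L, N, E)      # same layout as the query
    assert attn_weights.shape == (N, L, S)     # head-averaged because need_weights defaults to True

As in the stock PyTorch module, a float attn_mask is added to the attention logits, while a bool mask marks positions that are not allowed to attend.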
modeling/modules/criterion.py ADDED
@@ -0,0 +1,874 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
10
+ """
11
+ MaskFormer criterion.
12
+ """
13
+ import logging
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+
19
+ from detectron2.utils.comm import get_world_size
20
+ from timm.loss import SoftTargetCrossEntropy
21
+ from .point_features import (
22
+ get_uncertain_point_coords_with_randomness,
23
+ point_sample,
24
+ )
25
+
26
+ from ..language.loss import ql_multi_contrastive_loss, image_text_contrastive_loss_queue, vl_similarity, all_gather_grad
27
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list, _max_by_axis
28
+ from ..utils import box_ops
29
+
30
+ # from image2html.visualizer import VL
31
+
32
+
33
+ def dice_loss(
34
+ inputs: torch.Tensor,
35
+ targets: torch.Tensor,
36
+ num_masks: float,
37
+ ):
38
+ """
39
+ Compute the DICE loss, similar to generalized IOU for masks
40
+ Args:
41
+ inputs: A float tensor of arbitrary shape.
42
+ The predictions for each example.
43
+ targets: A float tensor with the same shape as inputs. Stores the binary
44
+ classification label for each element in inputs
45
+ (0 for the negative class and 1 for the positive class).
46
+ """
47
+ inputs = inputs.sigmoid()
48
+ inputs = inputs.flatten(1)
49
+ numerator = 2 * (inputs * targets).sum(-1)
50
+ denominator = inputs.sum(-1) + targets.sum(-1)
51
+ loss = 1 - (numerator + 1) / (denominator + 1)
52
+ return loss.sum() / num_masks
53
+
54
+
55
+ dice_loss_jit = torch.jit.script(
56
+ dice_loss
57
+ ) # type: torch.jit.ScriptModule
58
+
59
+
60
+ def sigmoid_ce_loss(
61
+ inputs: torch.Tensor,
62
+ targets: torch.Tensor,
63
+ num_masks: float,
64
+ ):
65
+ """
66
+ Args:
67
+ inputs: A float tensor of arbitrary shape.
68
+ The predictions for each example.
69
+ targets: A float tensor with the same shape as inputs. Stores the binary
70
+ classification label for each element in inputs
71
+ (0 for the negative class and 1 for the positive class).
72
+ Returns:
73
+ Loss tensor
74
+ """
75
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
76
+
77
+ return loss.mean(1).sum() / num_masks
78
+
79
+
80
+ sigmoid_ce_loss_jit = torch.jit.script(
81
+ sigmoid_ce_loss
82
+ ) # type: torch.jit.ScriptModule
83
+
84
+
85
+ def calculate_uncertainty(logits):
86
+ """
87
+ We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
88
+ foreground class in `classes`.
89
+ Args:
90
+ logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
91
+ class-agnostic, where R is the total number of predicted masks in all images.
92
+ The values are logits.
93
+ Returns:
94
+ scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
95
+ the most uncertain locations having the highest uncertainty score.
96
+ """
97
+ assert logits.shape[1] == 1
98
+ gt_class_logits = logits.clone()
99
+ return -(torch.abs(gt_class_logits))
100
+
101
+
102
+ class SetCriterion(nn.Module):
103
+ """This class computes the loss for DETR.
104
+ The process happens in two steps:
105
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
106
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
107
+ """
108
+
109
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, top_x_layers, losses,
110
+ num_points, oversample_ratio, importance_sample_ratio, grounding_weight):
111
+ """Create the criterion.
112
+ Parameters:
113
+ num_classes: number of object categories, omitting the special no-object category
114
+ matcher: module able to compute a matching between targets and proposals
115
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
116
+ eos_coef: relative classification weight applied to the no-object category
117
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
118
+ """
119
+ super().__init__()
120
+ self.num_classes = num_classes
121
+ self.matcher = matcher
122
+ self.weight_dict = weight_dict
123
+ self.eos_coef = eos_coef
124
+ self.top_x_layers = top_x_layers
125
+ self.losses = losses
126
+ empty_weight = torch.ones(self.num_classes + 1)
127
+ empty_weight[-1] = self.eos_coef
128
+ self.register_buffer("empty_weight", empty_weight)
129
+
130
+ # pointwise mask loss parameters
131
+ self.num_points = num_points
132
+ self.oversample_ratio = oversample_ratio
133
+ self.importance_sample_ratio = importance_sample_ratio
134
+
135
+ # grounding
136
+ self.grounding_weight = grounding_weight
137
+
138
+ def loss_labels(self, outputs, targets, indices, num_masks, layer_id, extra):
139
+ """Classification loss (NLL)
140
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
141
+ """
142
+ if layer_id > self.top_x_layers['mask']:
143
+ return {"loss_mask_ce_0": 0}
144
+
145
+ if indices is None or len(targets) == 0:
146
+ loss_ce = outputs['pred_logits'].sum() * 0.0
147
+ losses = {"loss_mask_ce_0": loss_ce}
148
+ return losses
149
+
150
+ assert "pred_logits" in outputs
151
+ src_logits = outputs["pred_logits"].type(self.empty_weight.dtype)
152
+
153
+ idx = self._get_src_permutation_idx(indices)
154
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
155
+ target_classes = torch.full(
156
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
157
+ )
158
+ target_classes[idx] = target_classes_o
159
+
160
+ if src_logits.shape[2] == self.num_classes+1:
161
+ empty_weight = torch.ones(self.num_classes + 1).to(src_logits.device).type(self.empty_weight.dtype)
162
+ empty_weight[-1] = self.eos_coef
163
+ else:
164
+ empty_weight = torch.ones(self.num_classes + 1000 + 1).to(src_logits.device).type(self.empty_weight.dtype)
165
+ empty_weight[self.num_classes] = self.eos_coef
166
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, empty_weight)
167
+ losses = {"loss_mask_ce_0": loss_ce}
168
+ return losses
169
+
170
+ def loss_labels_openimage(self, outputs, targets, indices, num_masks, layer_id, extra):
171
+ """Classification loss (NLL)
172
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
173
+ """
174
+ if layer_id > self.top_x_layers['mask']:
175
+ return {"loss_openimage_ce_0": 0}
176
+
177
+ assert "pred_captions" in outputs
178
+
179
+ if indices is None or len(targets) == 0 or (len(targets) > 0 and len(targets[0]['labels']) == 0):
180
+ loss_ce = outputs['pred_captions'].sum() * 0.0
181
+ losses = {"loss_openimage_ce_0": loss_ce}
182
+ return losses
183
+
184
+ # compute i2t loss
185
+ loss_openimage_ce = 0
186
+ losses = {}
187
+ for b in range(len(indices)):
188
+ pred_logit = outputs["pred_logits"][b][indices[b][0]]
189
+ gt_logit = torch.zeros_like(pred_logit)
190
+ select_idx = torch.stack((torch.arange(len(indices[b][1])), indices[b][1])).tolist()
191
+ gt_logit[select_idx] = 1
192
+ loss_openimage_ce += torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
193
+ loss_openimage_ce = loss_openimage_ce / len(indices)
194
+ losses.update({"loss_openimage_ce_0": loss_openimage_ce})
195
+ return losses
196
+
197
+ def loss_itc(self, outputs, targets, indices, num_masks, layer_id, extra):
198
+ if layer_id >= self.top_x_layers['retrieval']:
199
+ return {"loss_retrieval_decoder_0": 0}
200
+ t_emb = torch.cat([x['caption_proj'] for x in targets], dim=0)
201
+ v_emb = outputs['pred_captions'][:,-1]
202
+ loss_contrast = image_text_contrastive_loss_queue(v_emb, t_emb, extra['lang_encoder'], extra['training'])
203
+
204
+ # compute query-token contrastive loss
205
+ ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
206
+ ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
207
+ ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
208
+ vtk_emb = outputs['pred_captions'][:,:-1]
209
+ keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()
210
+
211
+ ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
212
+ vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
213
+ logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)
214
+
215
+ # prepare gt
216
+ gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
217
+ gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
218
+ # compute i2t loss
219
+ logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
220
+ loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
221
+ # loss_contrast_fine = loss_contrast_fine_vt # i2t only
222
+
223
+ # compute t2i loss
224
+ bs, nq, _ = vtk_emb.shape
225
+ logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
226
+ loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
227
+ # compute loss
228
+ loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)
229
+
230
+ losses = {"loss_retrieval_decoder_0": loss_contrast + loss_contrast_fine * 0.5}
231
+ return losses
232
+
233
+ def loss_captionings(self, outputs, targets, indices, num_masks, layer_id, extra):
234
+ if layer_id >= self.top_x_layers['captioning']:
235
+ return {"loss_captioning_0": 0}
236
+
237
+ pred_captions_gen = outputs['pred_captionings'][:, :-1]
238
+ token_embs = extra['token_embedding'].weight
239
+ # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7)
240
+ # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7)
241
+ pred_captions_gen = pred_captions_gen @ token_embs.t()
242
+
243
+ # temperature = extra['lang_encoder'].logit_scale
244
+ # logit_scale = temperature.exp().clamp(max=100)
245
+
246
+ target_captions_gen = torch.cat([target['caption_tokenids'] for target in targets], 0)[:, 1:]
247
+ target_captions_gen_mask = torch.cat([target['caption_mask'] for target in targets], 0)[:, 1:]
248
+
249
+ # loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2) * logit_scale, target_captions_gen, reduction='none')
250
+ loss_caption = F.cross_entropy(pred_captions_gen.transpose(1,2), target_captions_gen, reduction='none')
251
+ loss_caption = (loss_caption * target_captions_gen_mask).sum() / (target_captions_gen_mask.sum() + 1)
252
+ losses = {"loss_captioning_0": loss_caption}
253
+ return losses
254
+
255
+ def loss_captions(self, outputs, targets, indices, num_masks, layer_id, extra):
256
+ if layer_id >= self.top_x_layers['caption']:
257
+ return {"loss_caption_0": 0}
258
+ matched_tokens = [m[0] for m in indices]
259
+ t_emb_class = torch.cat([extra['class_embeddings'][targets[bs]['labels'][m[1]]] for bs, m in enumerate(indices)])
260
+ t_hash_class = torch.cat([torch.tensor(targets[bs]['labels_hash'])[m[1]] for bs, m in enumerate(indices)])
261
+
262
+ # pred_captions denotes all unmatched object queries.
263
+ unmatched_pred_captions = []
264
+ matched_pred_captions = []
265
+ for idx, m in enumerate(matched_tokens):
266
+ unmatched_masks = torch.ones(outputs['pred_captions'].shape[1:-1]).bool()
267
+ matched_masks = torch.zeros(outputs['pred_captions'].shape[1:-1]).bool()
268
+
269
+ unmatched_masks[m] = False
270
+ matched_masks[m] = True
271
+
272
+ unmatched_pred_captions.append(outputs['pred_captions'][idx][unmatched_masks])
273
+ matched_pred_captions.append(outputs['pred_captions'][idx][matched_masks])
274
+
275
+ outputs['unmatched_pred_captions'] = unmatched_pred_captions
276
+ v_emb_class = torch.cat(matched_pred_captions)
277
+ v_emb_class = v_emb_class / (v_emb_class.norm(dim=-1, keepdim=True) + 1e-7)
278
+
279
+ indices = self.matcher(outputs, targets, mode="caption_womask", extra={'temperature':extra['lang_logit']})
280
+ src_idx = self._get_src_permutation_idx(indices)
281
+
282
+ t_emb = torch.cat([t['captions'][indices[bs][1]] for bs,t in enumerate(targets)])
283
+ t_hash = torch.cat([torch.tensor(t['captions_hash'])[indices[bs][1]] for bs,t in enumerate(targets)])
284
+
285
+ unmatched_pred_captions, _ = nested_tensor_from_tensor_list(unmatched_pred_captions).decompose()
286
+ v_emb = unmatched_pred_captions[src_idx]
287
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
288
+
289
+ loss_contrast = ql_multi_contrastive_loss(torch.cat((v_emb, v_emb_class)), torch.cat((t_emb, t_emb_class)), torch.cat((t_hash, t_hash_class)), temperature=extra['lang_logit'])
290
+ losses = {"loss_caption_0": loss_contrast}
291
+
292
+ return losses
293
+
294
+ def loss_masks(self, outputs, targets, indices, num_masks, layer_id, extra):
295
+ """Compute the losses related to the masks: the focal loss and the dice loss.
296
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
297
+ """
298
+ if layer_id >= self.top_x_layers['mask']:
299
+ return {"loss_mask_bce_0": 0, "loss_mask_dice_0": 0}
300
+
301
+ assert "pred_masks" in outputs
302
+ if indices is None or len(targets) == 0:
303
+ loss = outputs['pred_masks'].sum() * 0.0
304
+ losses = {"loss_mask_bce_0": loss, "loss_mask_dice_0": loss}
305
+ return losses
306
+
307
+ src_idx = self._get_src_permutation_idx(indices)
308
+ tgt_idx = self._get_tgt_permutation_idx(indices)
309
+ src_masks = outputs["pred_masks"]
310
+ src_masks = src_masks[src_idx]
311
+ masks = [t["masks"] for t in targets]
312
+ # TODO use valid to mask invalid areas due to padding in loss
313
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
314
+ target_masks = target_masks.to(src_masks)
315
+ target_masks = target_masks[tgt_idx]
316
+ # No need to upsample predictions as we are using normalized coordinates :)
317
+ # N x 1 x H x W
318
+ src_masks = src_masks[:, None]
319
+ target_masks = target_masks[:, None]
320
+
321
+ with torch.no_grad():
322
+ # sample point_coords
323
+ point_coords = get_uncertain_point_coords_with_randomness(
324
+ src_masks,
325
+ lambda logits: calculate_uncertainty(logits),
326
+ self.num_points,
327
+ self.oversample_ratio,
328
+ self.importance_sample_ratio,
329
+ ).type(src_masks.dtype)
330
+ # get gt labels
331
+ point_labels = point_sample(
332
+ target_masks,
333
+ point_coords,
334
+ align_corners=False,
335
+ ).squeeze(1)
336
+
337
+ point_logits = point_sample(
338
+ src_masks,
339
+ point_coords,
340
+ align_corners=False,
341
+ ).squeeze(1)
342
+
343
+ losses = {
344
+ "loss_mask_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
345
+ "loss_mask_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
346
+ }
347
+
348
+ del src_masks
349
+ del target_masks
350
+ return losses
351
+
352
+ def loss_groundings(self, outputs, targets, indices, num_masks, layer_id, extra):
353
+ """Compute the losses related to the masks: the focal loss and the dice loss.
354
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
355
+ """
356
+ assert "pred_gmasks" in outputs
357
+ assert "pred_gtexts" in outputs
358
+
359
+ if layer_id >= self.top_x_layers['grounding']:
360
+ return {"loss_grounding_bce_0": 0, "loss_grounding_dice_0": 0, "loss_grounding_ce_0": 0}
361
+
362
+ masks = [t["grounding_masks"] for t in targets]
363
+ if indices is None or None in masks:
364
+ loss = outputs['pred_gmasks'].sum() * 0.0
365
+ return {"loss_grounding_bce_0": loss, "loss_grounding_dice_0": loss, "loss_grounding_ce_0": loss}
366
+
367
+ pred_logits = []
368
+ for b in range(len(indices)):
369
+ t_emb = targets[b]['grounding_class_embs']
370
+ v_emb = outputs["pred_gtexts"][b]
371
+
372
+ t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
373
+ v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
374
+
375
+ out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
376
+ pred_logits += [out_prob]
377
+ outputs['pred_logits'] = pred_logits
378
+
379
+ indices = self.matcher(outputs, targets, mode='grounding', extra={'temperature':extra['lang_logit']})
380
+ src_idx = self._get_src_permutation_idx(indices)
381
+ tgt_idx = self._get_tgt_permutation_idx(indices)
382
+
383
+ src_masks = outputs["pred_gmasks"]
384
+ src_masks = src_masks[src_idx]
385
+ # TODO use valid to mask invalid areas due to padding in loss
386
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
387
+ target_masks = target_masks.to(src_masks)
388
+ target_masks = target_masks[tgt_idx]
389
+ # No need to upsample predictions as we are using normalized coordinates :)
390
+ # N x 1 x H x W
391
+ src_masks = src_masks[:, None]
392
+ target_masks = target_masks[:, None]
393
+
394
+ with torch.no_grad():
395
+ # sample point_coords
396
+ point_coords = get_uncertain_point_coords_with_randomness(
397
+ src_masks,
398
+ lambda logits: calculate_uncertainty(logits),
399
+ self.num_points,
400
+ self.oversample_ratio,
401
+ self.importance_sample_ratio,
402
+ ).type(src_masks.dtype)
403
+ # get gt labels
404
+ point_labels = point_sample(
405
+ target_masks,
406
+ point_coords,
407
+ align_corners=False,
408
+ ).squeeze(1)
409
+
410
+ point_logits = point_sample(
411
+ src_masks,
412
+ point_coords,
413
+ align_corners=False,
414
+ ).squeeze(1)
415
+
416
+ losses = {
417
+ "loss_grounding_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, len(src_masks)),
418
+ "loss_grounding_dice_0": dice_loss_jit(point_logits, point_labels, len(src_masks)),
419
+ }
420
+
421
+ # compute query-token contrastive loss
422
+ # ttk_emb = torch.cat([x['caption_tokens'] for x in targets], dim=0)
423
+ # ttk_mask = torch.cat([x['caption_mask'] for x in targets], dim=0).float()
424
+ # ttk_mask = ttk_mask * torch.cumsum(ttk_mask, dim=1)
425
+ # vtk_emb = outputs['pred_captions'][:,:-1]
426
+ # keep = torch.cat([x['caption_mask'] for x in targets], dim=0).bool()
427
+
428
+ # ttk_emb = ttk_emb / (ttk_emb.norm(dim=-1, keepdim=True) + 1e-7)
429
+ # vtk_emb = vtk_emb / (vtk_emb.norm(dim=-1, keepdim=True) + 1e-7)
430
+ # logit_scale = extra['lang_encoder'].logit_scale.exp().clamp(max=100)
431
+
432
+ # # prepare gt
433
+ # gt = (torch.eye(vtk_emb.shape[0]).type_as(ttk_mask).unsqueeze(-1) * ttk_mask.unsqueeze(0).repeat(vtk_emb.shape[0], 1, 1))[:,keep].flatten(1)
434
+ # gt = gt / (gt.sum(1, keepdim=True) + 1e-7)
435
+ # # compute i2t loss
436
+ # logits = logit_scale * (vtk_emb @ ttk_emb[keep].transpose(0, 1)).mean(1)
437
+ # loss_contrast_fine_vt = SoftTargetCrossEntropy()(logits, gt)
438
+ # # loss_contrast_fine = loss_contrast_fine_vt # i2t only
439
+
440
+ # # compute t2i loss
441
+ # bs, nq, _ = vtk_emb.shape
442
+ # logits = logit_scale * (ttk_emb @ vtk_emb.flatten(0,1).transpose(0, 1)).reshape(bs,-1,bs,nq).mean(dim=-1)[keep,:]
443
+ # loss_contrast_fine_tv = SoftTargetCrossEntropy()(logits, gt.t())
444
+ # # compute loss
445
+ # loss_contrast_fine = (loss_contrast_fine_vt * 0.7 + loss_contrast_fine_tv * 0.3)
446
+
447
+ # compute t2i loss
448
+ loss_grd_ce = 0
449
+ for b in range(len(indices)):
450
+ task = targets[b]['grounding_task']
451
+ pred_logit = outputs["pred_logits"][b]
452
+ gt_logit = torch.zeros_like(pred_logit)
453
+ select_idx = torch.stack((indices[b][0], indices[b][1])).tolist()
454
+ gt_logit[select_idx] = 1
455
+ t_hash = torch.tensor(targets[b]['grounding_hash'], device=gt_logit.device)
456
+ hash_table = torch.zeros((len(t_hash), len(t_hash)), device=gt_logit.device)
457
+ for idx in range(0, len(hash_table)):
458
+ hash_table[idx][t_hash==t_hash[idx]] = 1
459
+ hash_table = hash_table / hash_table.sum(-1, keepdim=True)
460
+ gt_logit = gt_logit @ hash_table
461
+ loss_grd_ce += self.grounding_weight[task]*torch.sum(-gt_logit.t() * F.log_softmax(pred_logit.t(), dim=-1), dim=-1).mean()
462
+ loss_grd_ce = loss_grd_ce / len(indices)
463
+ losses.update({"loss_grounding_ce_0": loss_grd_ce})
464
+ del src_masks
465
+ del target_masks
466
+ return losses
467
+
468
+ def loss_spatials(self, outputs, targets, indices, num_masks, layer_id, extra):
469
+ """Compute the losses related to the masks: the focal loss and the dice loss.
470
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
471
+ """
472
+ assert "pred_smasks" in outputs
473
+ assert "pred_smaskembs" in outputs
474
+
475
+ if layer_id >= self.top_x_layers['spatial']:
476
+ loss = outputs['pred_smasks'].sum() * 0.0
477
+ loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
478
+ return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}
479
+
480
+ gt_masks = [x['gt_spatial_masks'] for x in targets]
481
+ # compute a keep index with batch size to avoid empty gt_masks
482
+ stack_gt_mask = torch.cat(gt_masks)
483
+ bs,_,_ = stack_gt_mask.shape
484
+ stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
485
+ keep = stack_gt_mask > 0 # only keep samples that contain a positive mask
486
+
487
+ if keep.sum() == 0:
488
+ loss = outputs['pred_smasks'].sum() * 0.0
489
+ loss_grd_ce = outputs["pred_smasks"].sum() * 0.0
490
+ return {"loss_spatial_bce_0": loss, "loss_spatial_dice_0": loss, "loss_spatial_ce_0": loss_grd_ce}
491
+
492
+ # mask embedding logits
493
+ v_emb = outputs["pred_smaskembs"] # [bs, nq, 512]
494
+
495
+ # pos mask
496
+ s_emb = outputs["pred_pspatials"] # [bs, ns, 512]
497
+ pred_logits = v_emb @ s_emb.transpose(1,2)
498
+ outputs['pred_pos_logits'] = pred_logits # [bs, nq, 1]
499
+ indices = self.matcher(outputs, targets, mode='spatial', extra={})
500
+ src_idx = self._get_src_permutation_idx(indices)
501
+ tgt_idx = self._get_tgt_permutation_idx(indices)
502
+
503
+ # pos class loss
504
+ pred_logit = torch.cat([o[:len(t['gt_spatial_masks'])] for o,t in zip(outputs["pred_pos_logits"].transpose(1,2), targets)])
505
+ gt_logit = torch.zeros_like(pred_logit)
506
+ gt_logit = gt_logit[keep]
507
+ _src_idx = [torch.arange(keep.sum(), device=src_idx[0].device), src_idx[1][keep.cpu()]]
508
+ gt_logit[_src_idx] = 1
509
+ pred_logit = pred_logit[keep]
510
+ loss_spa_ce_pos = torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1).mean()
511
+
512
+ # neg mask
513
+ # s_emb = outputs["pred_nspatials"] # [bs, ns, 512]
514
+ # neg_mask = (s_emb.sum(dim=list(range(1, len(s_emb.shape)))) != 0).float()[keep]
515
+ # pred_logits = v_emb @ s_emb.transpose(1,2)
516
+ # outputs['pred_neg_logits'] = pred_logits # [bs, nq, 1]
517
+ # indices = self.matcher(outputs, targets, mode='spatial_pn', extra=extra)
518
+ # src_idx = self._get_src_permutation_idx(indices)
519
+ # tgt_idx = self._get_tgt_permutation_idx(indices)
520
+ # src_masks_neg = outputs["pred_smasks"][src_idx][keep]
521
+ # src_masks_neg = src_masks_neg*(neg_mask[:,None,None])
522
+ # src_masks_neg = src_masks_neg.clip(0) * (-1)
523
+
524
+ # neg class loss
525
+ # pred_logit = outputs["pred_neg_logits"]
526
+ # gt_logit = torch.zeros_like(pred_logit)
527
+ # gt_logit[src_idx] = 1
528
+ # bs,_,ns = pred_logit[keep].shape
529
+ # pred_logit = pred_logit[keep].transpose(1,2).view(bs*ns,-1)
530
+ # gt_logit = gt_logit[keep].transpose(1,2).view(bs*ns,-1)
531
+ # loss_spa_ce_neg = (torch.sum(-gt_logit * F.log_softmax(pred_logit, dim=-1), dim=-1)*neg_mask).sum() / (neg_mask.sum()+1e-6)
532
+
533
+ # recompute a keep index with matched tgt
534
+ stack_gt_mask = nn.utils.rnn.pad_sequence(gt_masks, padding_value=-1).transpose(0,1)[tgt_idx]
535
+ bs,_,_ = stack_gt_mask.shape
536
+ target_masks = stack_gt_mask
537
+ stack_gt_mask = stack_gt_mask.view(bs,-1).sum(dim=-1)
538
+ keep = stack_gt_mask > 0 # only keep samples that contain a positive mask
539
+ src_masks_pos = outputs["pred_smasks"][src_idx][keep]
540
+
541
+ # TODO use valid to mask invalid areas due to padding in loss
542
+ target_masks = target_masks.to(src_masks_pos)
543
+ target_masks = target_masks[keep]
544
+
545
+ # mul = extra['spatial_query_mode'][keep]
546
+ # src_masks_cur = src_masks_cur.clip(0) * mul[:,None,None]
547
+ # src_masks_cur = src_masks_cur
548
+
549
+ # if neg_mask[0] == 1:
550
+ # import cv2
551
+ # print(src_masks_pos.shape)
552
+ # print(src_masks_neg.shape)
553
+ # print(target_masks.shape)
554
+ # # import pdb; pdb.set_trace()
555
+ # v_pos_mask = (src_masks_pos[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
556
+ # v_neg_mask = (_src_masks_neg[0].sigmoid() > 0.5).float().cpu().detach().numpy() * 255
557
+ # v_sum = ((src_masks_pos[0]-_src_masks_neg[0].clip(0)).sigmoid() > 0.5).float().cpu().detach().numpy() * 255
558
+ # v_gt = target_masks[0].float().cpu().detach().numpy() * 255
559
+
560
+ # cv2.imwrite('v_pos_mask.png', v_pos_mask)
561
+ # cv2.imwrite('v_neg_mask.png', v_neg_mask)
562
+ # cv2.imwrite('v_sum.png', v_sum)
563
+ # cv2.imwrite('v_gt.png', v_gt)
564
+ # import pdb; pdb.set_trace()
565
+
566
+ # src_masks = (src_masks_pos + src_masks_neg)[:, None]
567
+ src_masks = src_masks_pos[:, None]
568
+ target_masks = target_masks[:, None]
569
+
570
+ # debug visualization
571
+ # with torch.no_grad():
572
+ # import cv2
573
+ # import numpy as np
574
+
575
+ # v_src_masks = (F.interpolate(src_masks, size=target_masks.shape[-2:], mode='bilinear', align_corners=False).sigmoid() > 0.5).float().cpu().numpy()[:,0] * 255
576
+ # v_target_masks = target_masks.float().cpu().numpy()[:,0] * 255
577
+ # v_masks = np.concatenate([v_src_masks, v_target_masks], axis=2)
578
+
579
+ # for i in range(len(src_masks)):
580
+ # v1 = v_src_masks[i]
581
+ # v2 = v_target_masks[i]
582
+ # v = np.concatenate([v1,v2], axis=1)
583
+ # cv2.imwrite('v{}.png'.format(i), v)
584
+ # import pdb; pdb.set_trace()
585
+
586
+ # visualization
587
+ # VL.step()
588
+ # v_img = batched_inputs[0]['image'].permute(1,2,0).cpu().numpy()
589
+ # VL.add_image(v_img[:,:,::-1])
590
+ # candidate_masks = batched_inputs[0]['spatial_query']['rand_shape'].float().cpu().numpy()
591
+ # gt_masks = batched_inputs[0]['spatial_query']['gt_masks'].float().cpu().numpy()
592
+ # texts = ['cmask' for i in range(len(candidate_masks))]
593
+ # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], candidate_masks, texts)
594
+ # texts = ['gmask' for i in range(len(candidate_masks))]
595
+ # VL.overlay_obj_mask_to_image(v_img[:,:,::-1], gt_masks, texts)
596
+
597
+ # import cv2
598
+ # for i in range(len(src_masks)):
599
+ # visual_src_mask_cur = (src_masks_cur[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
600
+ # visual_src_mask_mem = (src_masks_mem[i].sigmoid()>0.5).detach().float().cpu().numpy() * 255
601
+ # visual_src_mask = (src_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
602
+ # visual_target_mask = (target_masks[i,0].sigmoid()>0.5).detach().float().cpu().numpy() * 255
603
+
604
+ # cv2.imwrite('visual_src_mask_cur_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_cur)
605
+ # cv2.imwrite('visual_src_mask_mem_{}_{}.png'.format(i, mul[i].item()), visual_src_mask_mem)
606
+ # cv2.imwrite('visual_src_mask_{}_{}.png'.format(i, mul[i].item()), visual_src_mask)
607
+ # cv2.imwrite('visual_target_mask_{}_{}.png'.format(i, mul[i].item()), visual_target_mask)
608
+ # import pdb; pdb.set_trace()
609
+
610
+ with torch.no_grad():
611
+ # sample point_coords
612
+ point_coords = get_uncertain_point_coords_with_randomness(
613
+ src_masks,
614
+ lambda logits: calculate_uncertainty(logits),
615
+ self.num_points,
616
+ self.oversample_ratio,
617
+ self.importance_sample_ratio,
618
+ ).type(src_masks.dtype)
619
+ # get gt labels
620
+ point_labels = point_sample(
621
+ target_masks,
622
+ point_coords,
623
+ align_corners=False,
624
+ ).squeeze(1)
625
+
626
+ point_logits = point_sample(
627
+ src_masks,
628
+ point_coords,
629
+ align_corners=False,
630
+ ).squeeze(1)
631
+
632
+ num_masks = len(src_masks)
633
+ losses = {
634
+ "loss_spatial_bce_0": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
635
+ "loss_spatial_dice_0": dice_loss_jit(point_logits, point_labels, num_masks),
636
+ }
637
+
638
+ # losses.update({"loss_spatial_ce_0": loss_spa_ce_pos + loss_spa_ce_neg})
639
+ losses.update({"loss_spatial_ce_0": loss_spa_ce_pos})
640
+
641
+ del src_masks
642
+ del target_masks
643
+ return losses
644
+
645
+ def loss_boxes(self, outputs, targets, indices, num_boxes, layer_id, extra):
646
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
647
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
648
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
649
+ """
650
+ if layer_id >= self.top_x_layers['box']:
651
+ return {"loss_bbox_0": 0, "loss_giou_0": 0}
652
+
653
+ assert 'pred_boxes' in outputs
654
+
655
+ if indices is None or len(targets) == 0:
656
+ loss = outputs['pred_boxes'].sum() * 0.0
657
+ losses = {"loss_bbox_0": loss, "loss_giou_0": loss}
658
+ return losses
659
+
660
+ src_idx = self._get_src_permutation_idx(indices)
661
+ tgt_idx = self._get_tgt_permutation_idx(indices)
662
+ src_boxes = outputs["pred_boxes"]
663
+ src_boxes = src_boxes[src_idx].sigmoid()
664
+
665
+ target_boxes = [t['boxes'] for t in targets]
666
+ max_size = _max_by_axis([list(box.shape) for box in target_boxes])
667
+ max_size = [len(target_boxes)] + max_size
668
+ empty_boxes = torch.zeros(max_size).to(src_boxes.device)
669
+ for idx, tar_box in enumerate(target_boxes):
670
+ empty_boxes[idx,:tar_box.shape[0],:] = tar_box
671
+ target_boxes = empty_boxes[tgt_idx]
672
+
673
+ # target_isthings = [t['is_things'] for t in targets]
674
+ # max_size = _max_by_axis([list(lab.shape) for lab in target_isthings])
675
+ # max_size = [len(target_isthings)] + max_size
676
+ # empty_lab = torch.zeros(max_size).to(src_boxes.device)
677
+
678
+ # for idx, tar_thing in enumerate(target_isthings):
679
+ # empty_lab[idx,:tar_thing.shape[0]] = tar_thing
680
+ # target_isthings = empty_lab[tgt_idx]
681
+
682
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
683
+ losses = {}
684
+ losses['loss_bbox_0'] = loss_bbox.sum() / num_boxes
685
+
686
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
687
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
688
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
689
+ losses['loss_giou_0'] = loss_giou.sum() / num_boxes
690
+ return losses
691
+
692
+ def _get_src_permutation_idx(self, indices):
693
+ # permute predictions following indices
694
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
695
+ src_idx = torch.cat([src for (src, _) in indices])
696
+ return batch_idx, src_idx
697
+
698
+ def _get_tgt_permutation_idx(self, indices):
699
+ # permute targets following indices
700
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
701
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
702
+ return batch_idx, tgt_idx
703
+
704
+ def get_loss(self, loss, outputs, targets, indices, num_masks, layer_id, extra):
705
+ loss_map = {
706
+ 'labels': self.loss_labels,
707
+ 'masks': self.loss_masks,
708
+ 'boxes': self.loss_boxes,
709
+ 'captions': self.loss_captions,
710
+ 'retrievals': self.loss_itc,
711
+ 'captionings': self.loss_captionings,
712
+ 'groundings': self.loss_groundings,
713
+ 'labels_openimage': self.loss_labels_openimage,
714
+ 'spatials': self.loss_spatials,
715
+ }
716
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
717
+ return loss_map[loss](outputs, targets, indices, num_masks, layer_id, extra)
718
+
719
+ def forward(self, outputs, targets, extra=None):
720
+ """This performs the loss computation.
721
+ Parameters:
722
+ outputs: dict of tensors, see the output specification of the model for the format
723
+ targets: list of dicts, such that len(targets) == batch_size.
724
+ The expected keys in each dict depend on the losses applied, see each loss' doc
725
+ """
726
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
727
+
728
+ # Retrieve the matching between the outputs of the last layer and the targets
729
+ indices = self.matcher(outputs_without_aux, targets)
730
+
731
+ # Compute the average number of target boxes across all nodes, for normalization purposes
732
+ num_masks = sum(len(t["labels"]) for t in targets)
733
+ num_masks = torch.as_tensor(
734
+ [num_masks], dtype=torch.float, device=next(iter(outputs_without_aux.values())).device
735
+ )
736
+ if is_dist_avail_and_initialized():
737
+ torch.distributed.all_reduce(num_masks)
738
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
739
+
740
+ # Compute all the requested losses
741
+ losses = {}
742
+ for loss in self.losses:
743
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
744
+
745
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
746
+ if "aux_outputs" in outputs:
747
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
748
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
749
+ indices = self.matcher(aux_outputs, targets)
750
+ for loss in self.losses:
751
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
752
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
753
+ losses.update(l_dict)
754
+
755
+ return losses
756
+
757
+ def forward_vlp(self, outputs, targets, extra=None):
758
+ """This performs the loss computation.
759
+ Parameters:
760
+ outputs: dict of tensors, see the output specification of the model for the format
761
+ targets: list of dicts, such that len(targets) == batch_size.
762
+ The expected keys in each dict depend on the losses applied, see each loss' doc
763
+ """
764
+ # Compute all the requested losses
765
+ losses = {}
766
+ num_masks = indices = None
767
+ for loss in self.losses:
768
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
769
+
770
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
771
+ if "aux_outputs" in outputs:
772
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
773
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
774
+ for loss in self.losses:
775
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
776
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
777
+ losses.update(l_dict)
778
+
779
+ return losses
780
+
781
+ def forward_grounding(self, outputs, targets, extra=None):
782
+ """This performs the loss computation.
783
+ Parameters:
784
+ outputs: dict of tensors, see the output specification of the model for the format
785
+ targets: list of dicts, such that len(targets) == batch_size.
786
+ The expected keys in each dict depend on the losses applied, see each loss' doc
787
+ """
788
+ # Compute all the requested losses
789
+ losses = {}
790
+ indices = [[] for i in range(len(targets))]
791
+
792
+ # Compute the average number of target boxes across all nodes, for normalization purposes
793
+ num_masks = sum(len(t["grounding_masks"]) for t in targets) + 1e-7
794
+ num_masks = torch.as_tensor(
795
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
796
+ )
797
+ if is_dist_avail_and_initialized():
798
+ torch.distributed.all_reduce(num_masks)
799
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
800
+
801
+ for loss in self.losses:
802
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
803
+
804
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
805
+ if "aux_outputs" in outputs:
806
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
807
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
808
+ for loss in self.losses:
809
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
810
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
811
+ losses.update(l_dict)
812
+
813
+ return losses
814
+
815
+ def forward_openimage(self, outputs, targets, extra=None):
816
+ """This performs the loss computation.
817
+ Parameters:
818
+ outputs: dict of tensors, see the output specification of the model for the format
819
+ targets: list of dicts, such that len(targets) == batch_size.
820
+ The expected keys in each dict depend on the losses applied, see each loss' doc
821
+ """
822
+ neg_class_emb = all_gather_grad(torch.cat([x['neg_class_emb'] for x in targets]))
823
+ neg_hash = all_gather_grad(torch.cat([x['neg_hash'] for x in targets]))
824
+
825
+ extra['neg_class_emb'] = neg_class_emb
826
+ extra['neg_hash'] = neg_hash
827
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
828
+
829
+ # Retrieve the matching between the outputs of the last layer and the targets
830
+ indices, pred_logits = self.matcher.openimage_forward(outputs_without_aux, targets, extra=extra)
831
+ outputs['pred_logits'] = pred_logits
832
+
833
+ # Compute the average number of target boxes across all nodes, for normalization purposes
834
+ num_masks = sum(len(t["labels"]) for t in targets)
835
+ num_masks = torch.as_tensor(
836
+ [num_masks], dtype=torch.float, device=neg_class_emb.device
837
+ )
838
+ if is_dist_avail_and_initialized():
839
+ torch.distributed.all_reduce(num_masks)
840
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
841
+
842
+ # Compute all the requested losses
843
+ losses = {}
844
+ for loss in self.losses:
845
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks, 0, extra))
846
+
847
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
848
+ if "aux_outputs" in outputs:
849
+ # NOTE: we reverse the aux_outputs so that the first is the second last layer
850
+ for i, aux_outputs in enumerate(outputs["aux_outputs"][::-1]):
851
+ indices, pred_logits = self.matcher.openimage_forward(aux_outputs, targets, extra=extra)
852
+ aux_outputs['pred_logits'] = pred_logits
853
+ for loss in self.losses:
854
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks, (i+1), extra)
855
+ l_dict = {k.replace('_0', f"_{i+1}"): v for k, v in l_dict.items()}
856
+ losses.update(l_dict)
857
+
858
+ return losses
859
+
860
+ def __repr__(self):
861
+ head = "Criterion " + self.__class__.__name__
862
+ body = [
863
+ "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
864
+ "losses: {}".format(self.losses),
865
+ "weight_dict: {}".format(self.weight_dict),
866
+ "num_classes: {}".format(self.num_classes),
867
+ "eos_coef: {}".format(self.eos_coef),
868
+ "num_points: {}".format(self.num_points),
869
+ "oversample_ratio: {}".format(self.oversample_ratio),
870
+ "importance_sample_ratio: {}".format(self.importance_sample_ratio),
871
+ ]
872
+ _repr_indent = 4
873
+ lines = [head] + [" " * _repr_indent + line for line in body]
874
+ return "\n".join(lines)
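
Note: a minimal wiring sketch for SetCriterion, not part of the committed diff. The import paths are assumed from this commit's file layout, and every hyper-parameter value below is a placeholder rather than the configuration shipped with the model configs:

    from modeling.modules.matcher import HungarianMatcher   # added in the next file of this commit
    from modeling.modules.criterion import SetCriterion     # this file

    matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0, num_points=12544)
    # top_x_layers caps how deep each loss is applied; keys follow the lookups in the code above
    top_x_layers = {'mask': 10, 'box': 10, 'caption': 10, 'captioning': 10,
                    'retrieval': 10, 'grounding': 10, 'spatial': 10}
    criterion = SetCriterion(
        num_classes=133,
        matcher=matcher,
        weight_dict={'loss_mask_ce_0': 2.0, 'loss_mask_bce_0': 5.0, 'loss_mask_dice_0': 5.0},
        eos_coef=0.1,
        top_x_layers=top_x_layers,
        losses=['labels', 'masks'],
        num_points=12544,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
        grounding_weight=None,
    )
    # losses = criterion(outputs, targets)
    # outputs needs 'pred_logits' and 'pred_masks' (plus optional 'aux_outputs');
    # targets is a list of per-image dicts with 'labels' and 'masks', as documented in forward().

With losses=['labels', 'masks'], only the Hungarian-matched classification loss and the point-sampled BCE/dice mask losses are computed; the remaining loss_* methods are opted into through the losses list.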
modeling/modules/matcher.py ADDED
@@ -0,0 +1,632 @@
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Modified by Xueyan Zou ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
10
+ """
11
+ Modules to compute the matching cost and solve the corresponding LSAP.
12
+ """
13
+ import warnings
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import numpy as np
17
+ from scipy.optimize import linear_sum_assignment
18
+ from torch import nn
19
+ from torch.cuda.amp import autocast
20
+
21
+ from .point_features import point_sample
22
+ from ..language.loss import vl_similarity
23
+
24
+ def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
25
+ """
26
+ Compute the DICE loss, similar to generalized IOU for masks
27
+ Args:
28
+ inputs: A float tensor of arbitrary shape.
29
+ The predictions for each example.
30
+ targets: A float tensor with the same shape as inputs. Stores the binary
31
+ classification label for each element in inputs
32
+ (0 for the negative class and 1 for the positive class).
33
+ """
34
+ inputs = inputs.sigmoid()
35
+ inputs = inputs.flatten(1)
36
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
37
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
38
+ loss = 1 - (numerator + 1) / (denominator + 1)
39
+ return loss
40
+
41
+
42
+ batch_dice_loss_jit = torch.jit.script(
43
+ batch_dice_loss
44
+ ) # type: torch.jit.ScriptModule
45
+
46
+
47
+ def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
48
+ """
49
+ Args:
50
+ inputs: A float tensor of arbitrary shape.
51
+ The predictions for each example.
52
+ targets: A float tensor with the same shape as inputs. Stores the binary
53
+ classification label for each element in inputs
54
+ (0 for the negative class and 1 for the positive class).
55
+ Returns:
56
+ Loss tensor
57
+ """
58
+ hw = inputs.shape[1]
59
+
60
+ pos = F.binary_cross_entropy_with_logits(
61
+ inputs, torch.ones_like(inputs), reduction="none"
62
+ )
63
+ neg = F.binary_cross_entropy_with_logits(
64
+ inputs, torch.zeros_like(inputs), reduction="none"
65
+ )
66
+
67
+ loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
68
+ "nc,mc->nm", neg, (1 - targets)
69
+ )
70
+
71
+ return loss / hw
72
+
73
+
74
+ batch_sigmoid_ce_loss_jit = torch.jit.script(
75
+ batch_sigmoid_ce_loss
76
+ ) # type: torch.jit.ScriptModule
77
+
78
+
79
+ class HungarianMatcher(nn.Module):
80
+ """This class computes an assignment between the targets and the predictions of the network
81
+
82
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
83
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
84
+ while the others are un-matched (and thus treated as non-objects).
85
+ """
86
+
87
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0, spatial_cost = None):
88
+ """Creates the matcher
89
+
90
+ Params:
91
+ cost_class: This is the relative weight of the classification error in the matching cost
92
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
93
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
94
+ """
95
+ super().__init__()
96
+ self.cost_class = cost_class
97
+ self.cost_mask = cost_mask
98
+ self.cost_dice = cost_dice
99
+
100
+ self.num_points = num_points
101
+ self.spatial_cost_class = cost_class
102
+ self.spatial_cost_mask = cost_mask
103
+ self.spatial_cost_dice = cost_dice
104
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
105
+
106
+ @torch.no_grad()
107
+ def memory_efficient_forward(self, outputs, targets):
108
+ """More memory-friendly matching"""
109
+ bs, num_queries = outputs["pred_logits"].shape[:2]
110
+
111
+ if bs == 0 or len(targets) == 0:
112
+ return None
113
+
114
+ indices = []
115
+
116
+ # Iterate through batch size
117
+ for b in range(bs):
118
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
119
+ tgt_ids = targets[b]["labels"]
120
+
121
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
122
+ # but approximate it in 1 - proba[target class].
123
+ # The 1 is a constant that doesn't change the matching, it can be ommitted.
124
+ cost_class = -out_prob[:, tgt_ids]
125
+
126
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
127
+ # gt masks are already padded when preparing target
128
+ tgt_mask = targets[b]["masks"].to(out_mask)
129
+
130
+ out_mask = out_mask[:, None]
131
+ tgt_mask = tgt_mask[:, None]
132
+ # all masks share the same set of points for efficient matching!
133
+ point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
134
+ # get gt labels
135
+ tgt_mask = point_sample(
136
+ tgt_mask,
137
+ point_coords.repeat(tgt_mask.shape[0], 1, 1),
138
+ align_corners=False,
139
+ ).squeeze(1)
140
+
141
+ out_mask = point_sample(
142
+ out_mask,
143
+ point_coords.repeat(out_mask.shape[0], 1, 1),
144
+ align_corners=False,
145
+ ).squeeze(1)
146
+
147
+ with autocast(enabled=False):
148
+ out_mask = out_mask.float()
149
+ tgt_mask = tgt_mask.float()
150
+ # Compute the focal loss between masks
151
+ cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
152
+
153
+ # Compute the dice loss betwen masks
154
+ cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
155
+
156
+ # Final cost matrix
157
+ C = (
158
+ self.cost_mask * cost_mask
159
+ + self.cost_class * cost_class
160
+ + self.cost_dice * cost_dice
161
+ )
162
+ C = C.reshape(num_queries, -1).cpu()
163
+ if C.isnan().any():
164
+ C[C.isnan()] = 1e6 ### temporary fix
165
+ warnings.warn("NAN in Cost Matrix!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
166
+ raise
167
+ indices.append(linear_sum_assignment(C))
168
+
169
+ return [
170
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
171
+ for i, j in indices
172
+ ]
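A note on the sampling used in the mask costs above: every predicted and ground-truth mask is evaluated at the same `self.num_points` random locations, so `cost_mask` and `cost_dice` are computed on small [N, num_points] tensors instead of full-resolution masks. A minimal sketch of that step with hypothetical shapes (`point_sample` is the detectron2-style helper imported at the top of this file; coordinates are normalized to [0, 1]):

import torch

masks = torch.randn(5, 1, 32, 32)        # 5 masks as [N, 1, H, W]
coords = torch.rand(1, 128, 2)           # 128 shared (x, y) points in [0, 1] x [0, 1]
sampled = point_sample(
    masks, coords.repeat(5, 1, 1), align_corners=False
).squeeze(1)                             # -> [5, 128], one row of sampled mask values per mask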
+
+     @torch.no_grad()
+     def openimage_forward(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, num_queries = outputs["pred_captions"].shape[:2]
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         neg_class_emb = extra['neg_class_emb']
+         neg_hash = extra['neg_hash']
+         _, unique_indices = np.unique(neg_hash.cpu().numpy(), return_index=True)
+         neg_class_emb = neg_class_emb[unique_indices]
+         neg_hash = neg_hash[unique_indices]
+
+         indices = []
+         pred_logits = []
+         # Iterate through batch size
+         for b in range(bs):
+             _pos_class_emb = targets[b]['pos_class_emb']
+             _pos_hash = targets[b]['pos_hash']
+             _neg_overlap_pos = ~(neg_hash[..., None] == _pos_hash).any(-1)
+             _neg_class_emb = neg_class_emb[_neg_overlap_pos]
+             t_emb = torch.cat((_pos_class_emb, _neg_class_emb))
+             v_emb = outputs["pred_captions"][b]
+             del _pos_class_emb
+             del _neg_class_emb
+
+             t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
+             v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+
+             out_prob = vl_similarity(v_emb, t_emb, temperature=extra['lang_logit'])
+             pred_logits += [out_prob]
+             out_prob = out_prob.softmax(-1)
+             tgt_ids = targets[b]["labels"]
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob[:, tgt_ids]
+
+             out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
+             # gt masks are already padded when preparing target
+             tgt_mask = targets[b]["masks"].to(out_mask)
+
+             out_mask = out_mask[:, None]
+             tgt_mask = tgt_mask[:, None]
+             # all masks share the same set of points for efficient matching!
+             point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+             # get gt labels
+             tgt_mask = point_sample(
+                 tgt_mask,
+                 point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             out_mask = point_sample(
+                 out_mask,
+                 point_coords.repeat(out_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             with autocast(enabled=False):
+                 out_mask = out_mask.float()
+                 tgt_mask = tgt_mask.float()
+                 # Compute the focal loss between masks
+                 cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+                 # Compute the dice loss between masks
+                 cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+             # Final cost matrix
+             C = (
+                 self.cost_mask * cost_mask
+                 + self.cost_class * cost_class
+                 + self.cost_dice * cost_dice
+             )
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ], pred_logits
+
+     @torch.no_grad()
+     def grounding_forward(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, num_queries = outputs["pred_gmasks"].shape[:2]
+
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         indices = []
+         # Iterate through batch size
+         for b in range(bs):
+             out_prob = outputs["pred_logits"][b]
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob.softmax(dim=0)
+
+             out_mask = outputs["pred_gmasks"][b]  # [num_queries, H_pred, W_pred]
+             # gt masks are already padded when preparing target
+             tgt_mask = targets[b]["grounding_masks"].to(out_mask)
+
+             out_mask = out_mask[:, None]
+             tgt_mask = tgt_mask[:, None]
+
+             # all masks share the same set of points for efficient matching!
+             point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+             # get gt labels
+             tgt_mask = point_sample(
+                 tgt_mask,
+                 point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             out_mask = point_sample(
+                 out_mask,
+                 point_coords.repeat(out_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             with autocast(enabled=False):
+                 out_mask = out_mask.float()
+                 tgt_mask = tgt_mask.float()
+                 # Compute the focal loss between masks
+                 cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+                 # Compute the dice loss between masks
+                 cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+             # Final cost matrix
+             C = (
+                 self.cost_mask * cost_mask
+                 + self.cost_class * cost_class
+                 + self.cost_dice * cost_dice
+             )
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ]
+
+     @torch.no_grad()
+     def spatial_forward(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, num_queries = outputs["pred_smasks"].shape[:2]
+
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         indices = []
+         # Iterate through batch size
+         for b in range(bs):
+             out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
+             # gt masks are already padded when preparing target
+             tgt_mask = targets[b]["gt_spatial_masks"].to(out_mask)
+             nd, ns = outputs["pred_pos_logits"][b].shape
+             index_masking = 1 - torch.eye(ns, device=out_mask.device, dtype=tgt_mask.dtype).repeat_interleave(nd // ns, dim=0)
+             neg_masking = torch.zeros((nd, ns), device=out_mask.device, dtype=tgt_mask.dtype)
+             neg_masking.masked_fill_(index_masking.bool(), -float('inf'))
+             pos_masking = torch.zeros((nd, ns), device=out_mask.device, dtype=tgt_mask.dtype)
+             pos_masking.masked_fill_(index_masking.bool(), float('inf'))
+             out_prob = (outputs["pred_pos_logits"][b] + neg_masking)[:, :len(tgt_mask)]  # remove redundant predictions for padding
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob.softmax(dim=0)
+
+             out_mask = out_mask[:, None]
+             tgt_mask = tgt_mask[:, None]
+
+             # all masks share the same set of points for efficient matching!
+             point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+             # get gt labels
+             tgt_mask = point_sample(
+                 tgt_mask,
+                 point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             out_mask = point_sample(
+                 out_mask,
+                 point_coords.repeat(out_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             with autocast(enabled=False):
+                 out_mask = out_mask.float()
+                 tgt_mask = tgt_mask.float()
+                 # Compute the focal loss between masks
+                 cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + pos_masking[:, :len(tgt_mask)]
+                 # Compute the dice loss between masks
+                 cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + pos_masking[:, :len(tgt_mask)]
+
+             # Final cost matrix
+             C = (
+                 self.spatial_cost_mask * cost_mask
+                 + self.spatial_cost_class * cost_class
+                 + self.spatial_cost_dice * cost_dice
+             )
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ]
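The masking in `spatial_forward` encodes a grouping constraint between queries and spatial prompts: with `nd` queries and `ns` prompts, each block of `nd // ns` consecutive queries may only match its own prompt; cross-prompt pairs get -inf added to the logits and +inf added to the mask/dice costs. A small sketch of the pattern this produces (this reading of the code is an interpretation, and the sizes are hypothetical):

import torch

nd, ns = 6, 3  # 6 queries grouped over 3 spatial prompts
index_masking = 1 - torch.eye(ns).repeat_interleave(nd // ns, dim=0)
print(index_masking)
# tensor([[0., 1., 1.],
#         [0., 1., 1.],
#         [1., 0., 1.],
#         [1., 0., 1.],
#         [1., 1., 0.],
#         [1., 1., 0.]])
# 1-entries mark query/prompt pairs that are forbidden in the assignment.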
+
+     @torch.no_grad()
+     def spatial_forward_pn(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, num_queries = outputs["pred_smasks"].shape[:2]
+
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         fp_mask = extra['false_positive_mask']
+         gt_mask = torch.stack([targets[b]["gt_spatial_masks"] for b in range(bs)])
+
+         indices = []
+         # Iterate through batch size
+         for b in range(bs):
+             out_prob = outputs["pred_neg_logits"][b]
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob.softmax(dim=0)
+
+             out_mask = outputs["pred_smasks"][b]  # [num_queries, H_pred, W_pred]
+             tgt_mask = fp_mask[b].to(out_mask)
+             ign_mask = (gt_mask[b] | fp_mask[b]).to(out_mask)
+
+             out_mask = out_mask[:, None]
+             tgt_mask = tgt_mask[:, None]
+             ign_mask = ign_mask[:, None]
+
+             # all masks share the same set of points for efficient matching!
+             point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+
+             # get gt labels
+             tgt_mask = point_sample(
+                 tgt_mask,
+                 point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             out_mask = point_sample(
+                 out_mask,
+                 point_coords.repeat(out_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             ign_mask = point_sample(
+                 ign_mask,
+                 point_coords.repeat(ign_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             with autocast(enabled=False):
+                 out_mask = out_mask.float()
+                 tgt_mask = tgt_mask.float()
+                 ign_mask = ign_mask.float()
+
+                 # Compute the focal loss between masks
+                 cost_mask = batch_sigmoid_ce_loss_jit(out_mask * ign_mask, tgt_mask * ign_mask)
+
+                 # Compute the dice loss between masks
+                 cost_dice = batch_dice_loss_jit(out_mask * ign_mask, tgt_mask * ign_mask)
+
+             # Final cost matrix
+             C = (
+                 self.spatial_cost_mask * cost_mask
+                 + self.spatial_cost_class * cost_class
+                 + self.spatial_cost_dice * cost_dice
+             )
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ]
+
+     @torch.no_grad()
+     def caption_forward_womask(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, _ = outputs["pred_logits"].shape[:2]
+
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         indices = []
+         t_emb = torch.cat([t['captions'] for t in targets])
+         v_emb = outputs['unmatched_pred_captions']
+         caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])
+
+         # Iterate through batch size
+         for b in range(bs):
+             v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
+             num_queries = len(v_emb[b])
+             out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
+             tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]
+
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob[:, tgt_ids]
+
+             # Final cost matrix
+             C = (self.cost_class * cost_class)
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ]
+
+     @torch.no_grad()
+     def caption_forward_wmask(self, outputs, targets, extra):
+         """More memory-friendly matching"""
+         bs, _ = outputs["pred_logits"].shape[:2]
+
+         if bs == 0 or len(targets) == 0:
+             return None
+
+         indices = []
+         t_emb = torch.cat([t['captions'] for t in targets])
+         v_emb = outputs['unmatched_pred_captions']
+         caption_target_count = np.cumsum([0] + [len(t['captions']) for t in targets])
+
+         # Iterate through batch size
+         for b in range(bs):
+             v_emb[b] = v_emb[b] / (v_emb[b].norm(dim=-1, keepdim=True) + 1e-7)
+             num_queries = len(v_emb[b])
+
+             out_prob = vl_similarity(v_emb[b][None,], t_emb, temperature=extra['temperature']).softmax(-1)[0]
+             tgt_ids = [idx for idx in range(caption_target_count[b], caption_target_count[b+1])]
+
+             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+             # but approximate it in 1 - proba[target class].
+             # The 1 is a constant that doesn't change the matching, it can be omitted.
+             cost_class = -out_prob[:, tgt_ids]
+
+             out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
+             # gt masks are already padded when preparing target
+             tgt_mask = targets[b]["masks"].to(out_mask)
+
+             out_mask = out_mask[:, None]
+             tgt_mask = tgt_mask[:, None]
+             # all masks share the same set of points for efficient matching!
+             point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device, dtype=tgt_mask.dtype)
+             # get gt labels
+             tgt_mask = point_sample(
+                 tgt_mask,
+                 point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             out_mask = point_sample(
+                 out_mask,
+                 point_coords.repeat(out_mask.shape[0], 1, 1),
+                 align_corners=False,
+             ).squeeze(1)
+
+             with autocast(enabled=False):
+                 out_mask = out_mask.float()
+                 tgt_mask = tgt_mask.float()
+                 # Compute the focal loss between masks
+                 cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
+
+                 # Compute the dice loss between masks
+                 cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
+
+             # Final cost matrix
+             C = (
+                 self.cost_mask * cost_mask
+                 + self.cost_class * cost_class
+                 + self.cost_dice * cost_dice
+             )
+             C = C.reshape(num_queries, -1).cpu()
+             if C.isnan().any():
+                 C[C.isnan()] = 1e6  # temporary fix
+                 warnings.warn("NaN in cost matrix")
+                 raise FloatingPointError("NaN in cost matrix")
+             indices.append(linear_sum_assignment(C))
+
+         return [
+             (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+             for i, j in indices
+         ]
+
+     @torch.no_grad()
+     def forward(self, outputs, targets, mode='default', extra={}):
+         """Performs the matching
+
+         Params:
+             outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+             targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+         Returns:
+             A list of size batch_size, containing tuples of (index_i, index_j) where:
+                 - index_i is the indices of the selected predictions (in order)
+                 - index_j is the indices of the corresponding selected targets (in order)
+             For each batch element, it holds:
+                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+         """
+         if mode == 'default':
+             return self.memory_efficient_forward(outputs, targets)
+         elif mode == 'grounding':
+             return self.grounding_forward(outputs, targets, extra)
+         elif mode == 'spatial':
+             return self.spatial_forward(outputs, targets, extra)
+         elif mode == 'spatial_pn':
+             return self.spatial_forward_pn(outputs, targets, extra)
+         elif mode == 'caption_womask':
+             return self.caption_forward_womask(outputs, targets, extra)
+         elif mode == 'caption_wmask':
+             return self.caption_forward_wmask(outputs, targets, extra)
+         else:
+             assert False, "Mode {} is not supported.".format(mode)
+
+     def __repr__(self, _repr_indent=4):
+         head = "Matcher " + self.__class__.__name__
+         body = [
+             "cost_class: {}".format(self.cost_class),
+             "cost_mask: {}".format(self.cost_mask),
+             "cost_dice: {}".format(self.cost_dice),
+         ]
+         lines = [head] + [" " * _repr_indent + line for line in body]
+         return "\n".join(lines)