InstantMesh / lrm /lrm.py
dylanebert
optional progress callback
c8a48ed
import collections
import itertools
import math
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import nvdiffrast.torch as dr
import torch
import torch.nn as nn
import torch.nn.functional as F
import xatlas
from diffusers import ConfigMixin, ModelMixin
from transformers import PreTrainedModel, ViTConfig, ViTImageProcessor
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.pytorch_utils import (
find_pruneable_heads_and_indices,
prune_linear_layer,
)
def generate_planes():
    """
    Return the axis matrices of the three triplane planes.

    Each plane is described by a 3x3 matrix whose rows are the three axes
    spanning it; the result has shape (3, 3, 3) and dtype float32. Every
    matrix is a row permutation of the identity:

        plane 0: rows (x, y, z)
        plane 1: rows (x, z, y)
        plane 2: rows (z, y, x)

    The third plane follows the corrected EG3D convention.
    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
    """
    identity = torch.eye(3, dtype=torch.float32)
    row_orders = ([0, 1, 2], [0, 2, 1], [2, 1, 0])
    # Advanced indexing copies the selected rows, so the stacked result is a
    # fresh contiguous tensor.
    return torch.stack([identity[order] for order in row_orders])
def project_onto_planes(planes, coordinates):
    """
    Project a batch of 3D points onto every plane, returning 2D coordinates.

    Args:
        planes: plane axis matrices of shape (n_planes, 3, 3).
        coordinates: points of shape (N, M, 3).

    Returns:
        Tensor of shape (N * n_planes, M, 2). Rows are ordered batch-major,
        plane-minor: index ``b * n_planes + p`` holds batch ``b`` projected
        onto plane ``p``. Each point is expressed in the plane's basis
        (right-multiplied by the inverse axis matrix), and the first two
        components are kept.
    """
    batch, num_points, _ = coordinates.shape
    num_planes, _, _ = planes.shape
    flat = batch * num_planes

    # Repeat each point cloud once per plane: (N, M, 3) -> (N * n_planes, M, 3).
    tiled_points = coordinates[:, None].expand(-1, num_planes, -1, -1)
    tiled_points = tiled_points.reshape(flat, num_points, 3)

    # Repeat the inverted axis matrices once per batch item.
    tiled_inverses = torch.linalg.inv(planes)[None].expand(batch, -1, -1, -1)
    tiled_inverses = tiled_inverses.reshape(flat, 3, 3)

    local = torch.bmm(tiled_points, tiled_inverses)
    return local[..., :2]
def sample_from_planes(
    plane_axes,
    plane_features,
    coordinates,
    mode="bilinear",
    padding_mode="zeros",
    box_warp=None,
):
    """
    Sample triplane features at a batch of 3D points.

    Args:
        plane_axes: plane axis matrices of shape (n_planes, 3, 3).
        plane_features: feature planes of shape (N, n_planes, C, H, W).
        coordinates: query points of shape (N, M, 3).
        mode: interpolation mode forwarded to ``grid_sample``.
        padding_mode: only ``"zeros"`` is supported.
        box_warp: extent of the sampling cube. Coordinates are scaled by
            ``2 / box_warp``, so a cube of side ``box_warp`` centred at the
            origin maps onto ``grid_sample``'s normalized [-1, 1] range.
            Required despite the ``None`` default.

    Returns:
        Tensor of shape (N, n_planes, M, C).

    Raises:
        ValueError: if ``padding_mode`` is not ``"zeros"`` or ``box_warp``
            is not given.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if padding_mode != "zeros":
        raise ValueError(
            f"sample_from_planes only supports padding_mode='zeros', got {padding_mode!r}"
        )
    if box_warp is None:
        # Previously this surfaced as an opaque `2 / None` TypeError.
        raise ValueError("sample_from_planes requires box_warp to be set")

    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (not view) so non-contiguous plane features are also accepted.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)
    dtype = plane_features.dtype

    coordinates = (2 / box_warp) * coordinates  # add specific box bounds
    # (N * n_planes, M, 2) -> (N * n_planes, 1, M, 2): a 1 x M sampling grid.
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
    output_features = (
        torch.nn.functional.grid_sample(
            plane_features,
            projected_coordinates.to(dtype),
            mode=mode,
            padding_mode=padding_mode,
            align_corners=False,
        )
        # grid_sample yields (N * n_planes, C, 1, M); move channels last.
        .permute(0, 3, 2, 1)
        .reshape(N, n_planes, M, C)
    )
    return output_features
class OSGDecoder(nn.Module):
    """
    Triplane decoder mapping sampled triplane features to SDF values,
    vertex deformations, per-cell FlexiCubes weights and RGB colours.

    The activation defaults to ReLU; the reference implementation uses
    Softplus.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(
        self,
        n_features: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        activation: nn.Module = nn.ReLU,
    ):
        super().__init__()

        def make_mlp(in_dim, out_dim):
            # Input projection, (num_layers - 2) hidden blocks, output head.
            # Child indices therefore match a flat nn.Sequential (0, 2, 4, ...).
            layers = [nn.Linear(in_dim, hidden_dim), activation()]
            for _ in range(num_layers - 2):
                layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
            layers.append(nn.Linear(hidden_dim, out_dim))
            return nn.Sequential(*layers)

        point_dim = 3 * n_features  # features of the three planes, concatenated
        # The heads are created in a fixed order (sdf, rgb, deformation,
        # weight) so random initialisation and state_dict keys stay stable.
        self.net_sdf = make_mlp(point_dim, 1)
        self.net_rgb = make_mlp(point_dim, 3)
        self.net_deformation = make_mlp(point_dim, 3)
        # Input is 8 gathered points per cell (presumably the 8 corners of a
        # FlexiCubes cell -- see get_geometry_prediction).
        self.net_weight = make_mlp(8 * point_dim, 21)

        # Every linear layer starts with a zero bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    @staticmethod
    def _merge_planes(sampled_features):
        """(N, n_planes, M, C) -> (N, M, n_planes * C), planes concatenated per point."""
        batch, num_planes, num_points, channels = sampled_features.shape
        return sampled_features.permute(0, 2, 1, 3).reshape(
            batch, num_points, num_planes * channels
        )

    def get_geometry_prediction(self, sampled_features, flexicubes_indices):
        """
        Decode geometry from sampled triplane features.

        Args:
            sampled_features: tensor of shape (N, n_planes, M, C).
            flexicubes_indices: integer tensor of shape (cells, k) indexing
                the M points; ``k * n_planes * C`` must equal the weight
                head's input width (k = 8 for the default layout).

        Returns:
            (sdf, deformation, weight) with shapes (N, M, 1), (N, M, 3) and
            (N, cells, 21); the weight head output is scaled by 0.1.
        """
        point_features = self._merge_planes(sampled_features)
        sdf = self.net_sdf(point_features)
        deformation = self.net_deformation(point_features)

        # Gather the features of every indexed point and concatenate them per cell.
        gathered = point_features.index_select(1, flexicubes_indices.reshape(-1))
        gathered = gathered.reshape(
            point_features.shape[0],
            flexicubes_indices.shape[0],
            flexicubes_indices.shape[1] * point_features.shape[-1],
        )
        weight = self.net_weight(gathered) * 0.1
        return sdf, deformation, weight

    def get_texture_prediction(self, sampled_features):
        """
        Decode RGB from sampled triplane features of shape (N, n_planes, M, C).

        Returns a tensor of shape (N, M, 3) in the open range (-0.001, 1.001).
        """
        logits = self.net_rgb(self._merge_planes(sampled_features))
        # Widened sigmoid ("sigmoid clamping" from MipNeRF) so exact 0/1 stay reachable.
        return torch.sigmoid(logits) * (1 + 2 * 0.001) - 0.001
class TriplaneSynthesizer(nn.Module):
    """
    Couples triplane feature sampling with an OSGDecoder to answer geometry
    and texture queries at arbitrary 3D points.

    ``rendering_kwargs`` holds the EG3D-style rendering settings; within this
    class only ``box_warp`` is read (as the sampling-cube extent).

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
    """

    DEFAULT_RENDERING_KWARGS = {
        "ray_start": "auto",
        "ray_end": "auto",
        "box_warp": 2.0,
        "white_back": True,
        "disparity_space_sampling": False,
        "clamp_mode": "softplus",
        "sampler_bbox_min": -1.0,
        "sampler_bbox_max": 1.0,
    }

    def __init__(self, triplane_dim: int, samples_per_ray: int):
        super().__init__()
        self.triplane_dim = triplane_dim

        # Copy the defaults so instances never mutate the shared class dict;
        # the per-ray sample budget is split evenly between the two passes.
        half_samples = samples_per_ray // 2
        kwargs = dict(self.DEFAULT_RENDERING_KWARGS)
        kwargs["depth_resolution"] = half_samples
        kwargs["depth_resolution_importance"] = half_samples
        self.rendering_kwargs = kwargs

        # Plain tensor attribute (not a registered buffer); it is moved to the
        # planes' device on every query instead.
        self.plane_axes = generate_planes()
        self.decoder = OSGDecoder(n_features=triplane_dim)

    def _sample_features(self, planes, sample_coordinates):
        """Sample ``planes`` at ``sample_coordinates`` using this module's plane axes."""
        return sample_from_planes(
            self.plane_axes.to(planes.device),
            planes,
            sample_coordinates,
            padding_mode="zeros",
            box_warp=self.rendering_kwargs["box_warp"],
        )

    def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
        """Return the decoder's (sdf, deformation, weight) at the sampled points."""
        features = self._sample_features(planes, sample_coordinates)
        return self.decoder.get_geometry_prediction(features, flexicubes_indices)

    def get_texture_prediction(self, planes, sample_coordinates):
        """Return the decoder's RGB prediction at the sampled points."""
        features = self._sample_features(planes, sample_coordinates)
        return self.decoder.get_texture_prediction(features)
dmc_table = [
[
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 8, 9, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 7, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 5, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 8, 9, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 7, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 7, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 5, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 8, 9, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 7, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 8, 11, -1, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 5, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 8, 9, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 7, 8, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 5, 7, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 8, 9, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 7, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 9, 10, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 7, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 5, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 5, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 8, 9, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 7, 9, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 7, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 8, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 9, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[8, 9, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[1, 3, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 7, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 9, 10, 11, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 9, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[1, 3, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 8, 10, 11, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 5, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 8, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 8, 9, -1, -1, -1],
[1, 3, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 7, 9, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 7, 8, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 8, 9, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 6, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 6, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 6, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 5, 8, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 6, 8, 9, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 6, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 6, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 5, 6, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 6, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 6, 7, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 6, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 6, 7, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 6, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 4, 6, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 6, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[2, 3, 6, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 6, 7, 8, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, -1, -1, -1],
[2, 3, 6, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 5, 6, 7, 8],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 6, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 6, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 5, 6, 8],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 5, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 10, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 9, 10, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 8, 9, 10, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 8, 11, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 6, 11, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 9, 10, -1, -1, -1],
[4, 6, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 6, 9, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 5, 9, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
],
[
[0, 2, 4, 5, 10, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 5, 8, 10, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 6, 8, 9, 11, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 6, 9, 11, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 6, 8, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 6, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 6, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 6, 7, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 6, 7, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[6, 7, 8, 9, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 6, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 6, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 6, 8, 9, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[1, 3, 6, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 6, 7, 8, 10, -1],
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 5, 6, 7, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 6, 7, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 5, 6, 8, 9, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 6, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 8, 9, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 7, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 7, 9, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 6, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 6, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 6, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[6, 7, 8, 9, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 6, 7, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 6, 7, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 6, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 11, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 8, 11, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 8, 9, 11, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 7, 11, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[5, 6, 10, -1, -1, -1, -1],
],
[
[1, 2, 4, 7, 9, 11, -1],
[5, 6, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 9, 10, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 8, 11, -1, -1, -1],
[4, 6, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 6, 10, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 6, 8, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[6, 7, 8, 9, 10, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 6, 7, 9, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 6, 7, 8, 10, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 6, 7, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 5, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 5, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 6, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 6, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[1, 2, 5, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 7, -1, -1, -1],
[1, 2, 5, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 6, 9, -1, -1],
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 5, 6, 7, 9],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 6, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 4, 6, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 6, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 6, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 6, 7, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 6, 7, 9],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 6, 7, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 6, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 5, 6, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 6, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 6, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 6, 8, 9, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[1, 3, 5, 6, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, 6, 7, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 6, 9, 11, -1],
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 6, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 6, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 6, 8, 9, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 6, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 6, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 6, 7, 8, 9, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 6, 7, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[6, 7, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[5, 7, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[5, 7, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 8, 9, -1, -1, -1],
[5, 7, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 8, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 5, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[4, 5, 8, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 5, 9, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 9, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[4, 7, 9, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 7, 10, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 7, 8, 10, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[8, 9, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 9, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 8, 10, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 10, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 7, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 5, 7, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 5, 7, 8, 9, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 5, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 5, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[2, 3, 4, 5, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 5, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 7, 9, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 7, 8, 9, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 4, 7, 10],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 8, 9, 10, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 9, 10, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 8, 10, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 10, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 5, 7, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 5, 7, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 5, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 5, 7, 8, 9, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 5, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 4, 5, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 5, 8, 9, 11],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 4, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[1, 2, 4, 7, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 4, 7, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 4, 7, 8, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 2, 8, 9, 11, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 2, 3, 9, 11, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 2, 8, 11, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[2, 3, 11, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 5, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 5, 7, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 5, 7, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[5, 7, 8, 9, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 5, 8, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 5, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 5, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 5, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 4, 7, 9, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 4, 7, 8, 9, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 4, 7, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[4, 7, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[1, 3, 8, 9, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 1, 9, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[0, 3, 8, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
[
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1, -1],
],
]
# Number of dual vertices a cube emits for each of the 256 corner-occupancy
# case ids (case id = sum over corners i of occ_i * 2**i, see
# FlexiCubes._get_case_id). Each line below holds 16 consecutive case ids; the
# trailing comment gives the range it covers.
num_vd_table = [
    0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1,  # 0-15
    1, 1, 2, 1, 2, 1, 3, 1, 2, 2, 2, 1, 2, 1, 2, 1,  # 16-31
    1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1,  # 32-47
    1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,  # 48-63
    1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,  # 64-79
    1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1,  # 80-95
    2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, 2, 2, 2, 1, 1,  # 96-111
    1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1,  # 112-127
    1, 2, 2, 2, 2, 2, 3, 2, 1, 2, 1, 1, 1, 1, 1, 1,  # 128-143
    2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1,  # 144-159
    1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1,  # 160-175
    1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1,  # 176-191
    1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1,  # 192-207
    1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1,  # 208-223
    1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1,  # 224-239
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,  # 240-255
]
# Ambiguity-resolution table read by FlexiCubes._get_case_id, indexed by case
# id (0-255). Each row is
#     [needs_check, dx, dy, dz, inverted_case_id]
# where (dx, dy, dz) is the grid offset of the neighbouring cube that shares
# the ambiguous face. Only the case ids listed in _AMBIGUOUS_CASE_NEIGHBOR need
# checking; for each of them the replacement case is the bitwise complement of
# its corner occupancy (255 - case_id). Every other row is all zeros.
_AMBIGUOUS_CASE_NEIGHBOR = {
    61: (1, 0, 0),
    62: (-1, 0, 0),
    91: (0, 1, 0),
    94: (0, -1, 0),
    103: (0, 0, 1),
    110: (0, 0, 1),
    111: (0, 0, 1),
    118: (0, 0, -1),
    122: (0, 1, 0),
    123: (0, 1, 0),
    124: (1, 0, 0),
    125: (1, 0, 0),
    155: (0, 0, 1),
    157: (0, 0, 1),
    159: (0, 0, 1),
    167: (0, 1, 0),
    173: (0, -1, 0),
    181: (0, 1, 0),
    183: (0, 1, 0),
    185: (0, 0, -1),
    188: (-1, 0, 0),
    190: (-1, 0, 0),
    199: (1, 0, 0),
    203: (-1, 0, 0),
    211: (1, 0, 0),
    215: (1, 0, 0),
    217: (0, 0, -1),
    218: (0, -1, 0),
    222: (0, -1, 0),
    227: (-1, 0, 0),
    229: (0, -1, 0),
    230: (0, 0, -1),
    235: (-1, 0, 0),
    237: (0, -1, 0),
    246: (0, 0, -1),
    249: (0, 0, -1),
}
check_table = [
    [1, *_AMBIGUOUS_CASE_NEIGHBOR[case_id], 255 - case_id]
    if case_id in _AMBIGUOUS_CASE_NEIGHBOR
    else [0, 0, 0, 0, 0]
    for case_id in range(256)
]
# Per-case lookup table, converted once to the long tensor
# `FlexiCubes.tet_table` in `FlexiCubes.__init__`. It has one row for each of
# the 256 corner-occupancy case ids (row k <=> case id k) and six integer
# entries per row; the "# cases a-b" markers below label each run of 16 rows.
# NOTE(review): no method visible in this chunk reads `tet_table`; it is
# presumably consumed by `_tetrahedralize` (called from `__call__` when
# `output_tetmesh=True`). The meaning of the six columns (possibly one per
# cube face) and of the values -1 and 0..12 is not established here — TODO
# confirm against `_tetrahedralize`.
tet_table = [
    # cases 0-15
    [-1, -1, -1, -1, -1, -1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, -1],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [0, 4, 0, 4, 4, -1],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [5, 5, 5, 5, 5, 5],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 16-31
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, -1, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, -1, 2, 4, 4, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 4, 4, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 4, 2, 4, 4, 2],
    [0, 4, 0, 4, 4, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 5, 2, 5, 5, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    # cases 32-47
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, -1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [4, 1, 1, 4, 4, 1],
    [0, 1, 1, 0, 0, 1],
    [4, 0, 0, 4, 4, 0],
    [2, 2, 2, 2, 2, 2],
    [-1, 1, 1, 4, 4, 1],
    [0, 1, 1, 4, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [5, 1, 1, 5, 5, 1],
    [0, 1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    # cases 48-63
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [8, 8, 8, 8, 8, 8],
    [1, 1, 1, 4, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 4, 4, 1],
    [0, 4, 0, 4, 4, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 5, 5, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    # cases 64-79
    [6, 6, 6, 6, 6, 6],
    [6, -1, 0, 6, 0, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 1, 1, 6, 1, 6],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [6, 4, -1, 6, 4, 6],
    [6, 4, 0, 6, 4, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 1, 1, 6, 1, 6],
    [5, 5, 5, 5, 5, 5],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 80-95
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 4, 2, 2, 4, 2],
    [0, 4, 0, 4, 4, 0],
    [2, 0, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 96-111
    [6, 1, 1, 6, -1, 6],
    [6, 1, 1, 6, 0, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 2, 2, 6, 2, 6],
    [4, 1, 1, 4, 4, 1],
    [0, 1, 1, 0, 0, 1],
    [4, 0, 0, 4, 4, 4],
    [2, 2, 2, 2, 2, 2],
    [6, 1, 1, 6, 4, 6],
    [6, 1, 1, 6, 4, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 2, 2, 6, 2, 6],
    [5, 1, 1, 5, 5, 1],
    [0, 1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    # cases 112-127
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [6, 6, 6, 6, 6, 6],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 4, 1],
    [0, 4, 0, 4, 4, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 5, 0, 5, 0, 5],
    [5, 5, 5, 5, 5, 5],
    # cases 128-143
    [5, 5, 5, 5, 5, 5],
    [0, 5, 0, 5, 0, 5],
    [-1, 5, 0, 5, 0, 5],
    [1, 5, 1, 5, 1, 5],
    [4, 5, -1, 5, 4, 5],
    [0, 5, 0, 5, 0, 5],
    [4, 5, 0, 5, 4, 5],
    [1, 5, 1, 5, 1, 5],
    [4, 4, 4, 4, 4, 4],
    [0, 4, 0, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [6, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 144-159
    [2, 5, 2, 5, -1, 5],
    [0, 5, 0, 5, 0, 5],
    [2, 5, 2, 5, 0, 5],
    [1, 5, 1, 5, 1, 5],
    [2, 5, 2, 5, 4, 5],
    [0, 5, 0, 5, 0, 5],
    [2, 5, 2, 5, 4, 5],
    [1, 5, 1, 5, 1, 5],
    [2, 4, 2, 4, 4, 2],
    [0, 4, 0, 4, 4, 4],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 6, 2, 6, 6, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    # cases 160-175
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [4, 1, 1, 1, 4, 1],
    [0, 1, 1, 1, 0, 1],
    [4, 0, 0, 4, 4, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    # cases 176-191
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    [1, 1, 1, 1, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [6, 0, 0, 6, 0, 6],
    [0, 0, 0, 0, 0, 0],
    [6, 6, 6, 6, 6, 6],
    # cases 192-207
    [5, 5, 5, 5, 5, 5],
    [5, 5, 0, 5, 0, 5],
    [5, 5, 0, 5, 0, 5],
    [5, 5, 1, 5, 1, 5],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 0, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [4, 4, 0, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [8, 8, 8, 8, 8, 8],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 208-223
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 1, 1, 4, 4, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    # cases 224-239
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 4, 2, 4, 4, 2],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    # cases 240-255
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [12, 12, 12, 12, 12, 12],
]
class FlexiCubes:
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
    """
    Build the FlexiCubes lookup tables and unit-cube geometry on ``device``.

    Args:
        device (str or torch.device): Device on which every lookup tensor is
            allocated. Defaults to "cuda".
        qef_reg_scale (float): Weight of the regularization rows appended to the
            QEF system in ``_solve_vd_QEF`` (used when a ``grad_func`` is given).
            Defaults to 1e-3.
        weight_scale (float): Scale applied to the tanh/sigmoid-squashed beta,
            alpha and gamma weights in ``_normalize_weights``. Defaults to 0.99.
    """
    self.device = device

    # Module-level lookup tables, copied once into long tensors on `device`.
    self.dmc_table = torch.tensor(
        dmc_table, dtype=torch.long, device=device, requires_grad=False
    )
    self.num_vd_table = torch.tensor(
        num_vd_table, dtype=torch.long, device=device, requires_grad=False
    )
    self.check_table = torch.tensor(
        check_table, dtype=torch.long, device=device, requires_grad=False
    )
    self.tet_table = torch.tensor(
        tet_table, dtype=torch.long, device=device, requires_grad=False
    )

    # Two triangulations of a quad (v0, v1, v2, v3): split along the 0-2
    # diagonal, or along the 1-3 diagonal.
    self.quad_split_1 = torch.tensor(
        [0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False
    )
    self.quad_split_2 = torch.tensor(
        [0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False
    )
    # Index pairs (0,1), (1,2), (2,3), (3,0): the boundary edges of a quad.
    self.quad_split_train = torch.tensor(
        [0, 1, 1, 2, 2, 3, 3, 0],
        dtype=torch.long,
        device=device,
        requires_grad=False,
    )

    # Unit-cube corner i sits at (i & 1, (i >> 1) & 1, (i >> 2) & 1).
    self.cube_corners = torch.tensor(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [1, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
            [0, 1, 1],
            [1, 1, 1],
        ],
        dtype=torch.float,
        device=device,
    )
    # Bit weight 2**i of corner i, used to pack the 8 corner occupancies into
    # a case id in [0, 255]. Allocated on `device` like every other table, so
    # _get_case_id does not need a host-to-device copy on each call.
    self.cube_corners_idx = torch.pow(
        2, torch.arange(8, device=device, requires_grad=False)
    )
    # The 12 cube edges as 24 flattened corner-index pairs.
    self.cube_edges = torch.tensor(
        [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4],
        dtype=torch.long,
        device=device,
        requires_grad=False,
    )
    # Axis along which each of the 12 edges above runs (0 = x, 1 = y, 2 = z).
    self.edge_dir_table = torch.tensor(
        [0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], dtype=torch.long, device=device
    )
    # NOTE(review): dir_faces_table and adj_pairs are not read by any method
    # visible in this chunk; presumably used by _triangulate/_tetrahedralize.
    self.dir_faces_table = torch.tensor(
        [
            [[5, 4], [3, 2], [4, 5], [2, 3]],
            [[5, 4], [1, 0], [4, 5], [0, 1]],
            [[3, 2], [1, 0], [2, 3], [0, 1]],
        ],
        dtype=torch.long,
        device=device,
    )
    self.adj_pairs = torch.tensor(
        [0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device
    )
    self.qef_reg_scale = qef_reg_scale
    self.weight_scale = weight_scale
def construct_voxel_grid(self, res):
"""
Generates a voxel grid based on the specified resolution.
Args:
res (int or list[int]): The resolution of the voxel grid. If an integer
is provided, it is used for all three dimensions. If a list or tuple
of 3 integers is provided, they define the resolution for the x,
y, and z dimensions respectively.
Returns:
(torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
cube corners (index into vertices) of the constructed voxel grid.
The vertices are centered at the origin, with the length of each
dimension in the grid being one.
"""
base_cube_f = torch.arange(8).to(self.device)
if isinstance(res, int):
res = (res, res, res)
voxel_grid_template = torch.ones(res, device=self.device)
res = torch.tensor([res], dtype=torch.float, device=self.device)
coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(
-1, 3
)
cubes = (
base_cube_f.unsqueeze(0)
+ torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8
).reshape(-1)
verts_rounded = torch.round(verts * 10**5) / (10**5)
verts_unique, inverse_indices = torch.unique(
verts_rounded, dim=0, return_inverse=True
)
cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
return verts_unique - 0.5, cubes
def __call__(
self,
x_nx3,
s_n,
cube_fx8,
res,
beta_fx12=None,
alpha_fx8=None,
gamma_f=None,
training=False,
output_tetmesh=False,
grad_func=None,
):
r"""
Main function for mesh extraction from scalar field using FlexiCubes. This function converts
discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
to triangle or tetrahedral meshes using a differentiable operation as described in
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
mesh quality and geometric fidelity by adjusting the surface representation based on gradient
optimization. The output surface is differentiable with respect to the input vertex positions,
scalar field values, and weight parameters.
If you intend to extract a surface mesh from a fixed Signed Distance Field without the
optimization of parameters, it is suggested to provide the "grad_func" which should
return the surface gradient at any given 3D position. When grad_func is provided, the process
to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
Please note, this approach is non-differentiable.
For more details and example usage in optimization, refer to the
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
Args:
x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
denote that the corresponding vertex resides inside the isosurface. This affects
the directions of the extracted triangle faces and volume to be tetrahedralized.
cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
is used for all three dimensions. If a list or tuple of 3 integers is provided, they
specify the resolution for the x, y, and z dimensions respectively.
beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
vertices positioning. Defaults to uniform value for all edges.
alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
vertices positioning. Defaults to uniform value for all vertices.
gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
quadrilaterals into triangles. Defaults to uniform value for all cubes.
training (bool, optional): If set to True, applies differentiable quad splitting for
training. Defaults to False.
output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
outputs a triangular mesh. Defaults to False.
grad_func (callable, optional): A function to compute the surface gradient at specified
3D positions (input: Nx3 positions). The function should return gradients as an Nx3
tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
Returns:
(torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
- Vertices for the extracted triangular/tetrahedral mesh.
- Faces for the extracted triangular/tetrahedral mesh.
- Regularizer L_dev, computed per dual vertex.
.. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
https://research.nvidia.com/labs/toronto-ai/flexicubes/
.. _Manifold Dual Contouring:
https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
"""
surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
if surf_cubes.sum() == 0:
return (
torch.zeros((0, 3), device=self.device),
(
torch.zeros((0, 4), dtype=torch.long, device=self.device)
if output_tetmesh
else torch.zeros((0, 3), dtype=torch.long, device=self.device)
),
torch.zeros((0), device=self.device),
)
beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(
beta_fx12, alpha_fx8, gamma_f, surf_cubes
)
case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
s_n, cube_fx8, surf_cubes
)
vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
x_nx3,
cube_fx8[surf_cubes],
surf_edges,
s_n,
case_ids,
beta_fx12,
alpha_fx8,
gamma_f,
idx_map,
grad_func,
)
vertices, faces, s_edges, edge_indices = self._triangulate(
s_n,
surf_edges,
vd,
vd_gamma,
edge_counts,
idx_map,
vd_idx_map,
surf_edges_mask,
training,
grad_func,
)
if not output_tetmesh:
return vertices, faces, L_dev
else:
vertices, tets = self._tetrahedralize(
x_nx3,
s_n,
cube_fx8,
vertices,
faces,
surf_edges,
s_edges,
vd_idx_map,
case_ids,
edge_indices,
surf_cubes,
training,
)
return vertices, tets, L_dev
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
"""
Regularizer L_dev as in Equation 8
"""
dist = torch.norm(
ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1
)
mean_l2 = torch.zeros_like(vd[:, 0])
mean_l2 = (mean_l2).index_add_(
0, edge_group_to_vd, dist
) / vd_num_edges.squeeze(1).float()
mad = (
dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)
).abs()
return mad
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
"""
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
"""
n_cubes = surf_cubes.shape[0]
if beta_fx12 is not None:
beta_fx12 = torch.tanh(beta_fx12) * self.weight_scale + 1
else:
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
if alpha_fx8 is not None:
alpha_fx8 = torch.tanh(alpha_fx8) * self.weight_scale + 1
else:
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
if gamma_f is not None:
gamma_f = (
torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale) / 2
)
else:
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
@torch.no_grad()
def _get_case_id(self, occ_fx8, surf_cubes, res):
    """
    Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
    ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
    supplementary material. It should be noted that this function assumes a regular grid.

    Args:
        occ_fx8 (torch.Tensor): Boolean corner occupancy, 8 columns per cube, for all
            cubes of the grid (surface and non-surface).
        surf_cubes (torch.Tensor): Boolean mask over the cubes selecting surface cubes.
        res (int or list[int]): Grid resolution; an int is used for all three axes.

    Returns:
        torch.Tensor: One case id in [0, 255] per surface cube, with bit i set when
        corner i is occupied. A case flagged in ``check_table`` is replaced by the
        table's inverted case when the neighbouring cube across its ambiguous face
        is flagged as well.

    NOTE(review): cube i is assumed to be the i-th index produced by ``torch.nonzero``
    over a full grid of shape ``res`` (the ordering ``construct_voxel_grid`` uses),
    so the number of cubes must equal the product of ``res``.
    """
    # Pack the 8 corner occupancies into an integer: case id = sum_i occ_i * 2**i.
    case_ids = (
        occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)
    ).sum(-1)
    # check_table rows are [needs_check, dx, dy, dz, inverted case id].
    problem_config = self.check_table.to(self.device)[case_ids]
    to_check = problem_config[..., 0] == 1
    problem_config = problem_config[to_check]
    if not isinstance(res, (list, tuple)):
        res = [res, res, res]
    # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
    # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
    # This allows efficient checking on adjacent cubes.
    problem_config_full = torch.zeros(
        list(res) + [5], device=self.device, dtype=torch.long
    )
    # The grid is all zeros here, so this yields the (x, y, z) index of every cube,
    # in torch.nonzero order.
    vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
    vol_idx_problem = vol_idx[surf_cubes][to_check]
    problem_config_full[
        vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]
    ] = problem_config
    # Grid index of the neighbour across each ambiguous face.
    vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
    # Neighbours outside the grid cannot take part in the check; drop them.
    within_range = (
        (vol_idx_problem_adj[..., 0] >= 0)
        & (vol_idx_problem_adj[..., 0] < res[0])
        & (vol_idx_problem_adj[..., 1] >= 0)
        & (vol_idx_problem_adj[..., 1] < res[1])
        & (vol_idx_problem_adj[..., 2] >= 0)
        & (vol_idx_problem_adj[..., 2] < res[2])
    )
    vol_idx_problem = vol_idx_problem[within_range]
    vol_idx_problem_adj = vol_idx_problem_adj[within_range]
    problem_config = problem_config[within_range]
    problem_config_adj = problem_config_full[
        vol_idx_problem_adj[..., 0],
        vol_idx_problem_adj[..., 1],
        vol_idx_problem_adj[..., 2],
    ]
    # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
    to_invert = problem_config_adj[..., 0] == 1
    # Replay the same to_check / within_range / to_invert filters on the positions
    # 0..len(case_ids)-1 to find which entries of case_ids to overwrite.
    idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][
        within_range
    ][to_invert]
    case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
    return case_ids
@torch.no_grad()
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
"""
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
and marks the cube edges with this index.
"""
occ_n = s_n < 0
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
unique_edges, _idx_map, counts = torch.unique(
all_edges, dim=0, return_inverse=True, return_counts=True
)
unique_edges = unique_edges.long()
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
surf_edges_mask = mask_edges[_idx_map]
counts = counts[_idx_map]
mapping = (
torch.ones(
(unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device
)
* -1
)
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
idx_map = mapping[_idx_map]
surf_edges = unique_edges[mask_edges]
return surf_edges, idx_map, counts, surf_edges_mask
@torch.no_grad()
def _identify_surf_cubes(self, s_n, cube_fx8):
"""
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
all corners are not identical.
"""
occ_n = s_n < 0
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
_occ_sum = torch.sum(occ_fx8, -1)
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
return surf_cubes, occ_fx8
def _linear_interp(self, edges_weight, edges_x):
"""
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
"""
edge_dim = edges_weight.dim() - 2
assert edges_weight.shape[edge_dim] == 2
edges_weight = torch.cat(
[
torch.index_select(
input=edges_weight,
index=torch.tensor(1, device=self.device),
dim=edge_dim,
),
-torch.index_select(
input=edges_weight,
index=torch.tensor(0, device=self.device),
dim=edge_dim,
),
],
edge_dim,
)
denominator = edges_weight.sum(edge_dim)
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
return ue
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
c_bx3 = c_bx3.reshape(-1, 3)
A = norm_bxnx3
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
A_reg = (
(torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale)
.unsqueeze(0)
.repeat(p_bxnx3.shape[0], 1, 1)
)
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
A = torch.cat([A, A_reg], 1)
B = torch.cat([B, B_reg], 1)
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
return dual_verts
def _compute_vd(
    self,
    x_nx3,
    surf_cubes_fx8,
    surf_edges,
    s_n,
    case_ids,
    beta_fx12,
    alpha_fx8,
    gamma_f,
    idx_map,
    grad_func,
):
    """
    Computes the location of dual vertices as described in Section 4.2

    Args:
        x_nx3 (torch.Tensor): Grid vertex positions.
        surf_cubes_fx8 (torch.Tensor): The 8 vertex indices of each surface cube.
        surf_edges (torch.Tensor): Vertex-index pairs of the surface-crossing edges.
        s_n (torch.Tensor): Scalar value per grid vertex.
        case_ids (torch.Tensor): Topology case id per surface cube.
        beta_fx12, alpha_fx8, gamma_f (torch.Tensor): Normalized per-edge, per-corner
            and per-cube weights of the surface cubes.
        idx_map (torch.Tensor): For each of the 12 edges of each surface cube, the
            index of that edge in ``surf_edges``, or -1 if it does not cross.
        grad_func (callable or None): If given, dual vertices are obtained by solving
            a QEF built from the normalized gradients at the edge zero crossings.
            That computation runs under ``torch.no_grad()``.

    Returns:
        vd (torch.Tensor): (V, 3) dual vertex positions.
        L_dev (torch.Tensor): Output of ``_compute_reg_loss`` (one value per edge
            assigned to a dual vertex), or a single zero when ``grad_func`` is given.
        vd_gamma (torch.Tensor): Gamma of the cube each dual vertex belongs to.
        vd_idx_map (torch.Tensor): Flattened (num_surface_cubes * 12,) map from each
            cube edge to the dual vertex that uses it; unused slots stay 0.
    """
    # alpha weights of the two endpoints of each of the 12 edges, per cube.
    alpha_nx12x2 = torch.index_select(
        input=alpha_fx8, index=self.cube_edges, dim=1
    ).reshape(-1, 12, 2)
    # Endpoint positions and scalar values of every surface edge.
    surf_edges_x = torch.index_select(
        input=x_nx3, index=surf_edges.reshape(-1), dim=0
    ).reshape(-1, 2, 3)
    surf_edges_s = torch.index_select(
        input=s_n, index=surf_edges.reshape(-1), dim=0
    ).reshape(-1, 2, 1)
    # Unweighted zero crossing on each surface edge (one row per surf_edges entry).
    zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
    idx_map = idx_map.reshape(-1, 12)
    # Number of dual vertices each surface cube emits.
    num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
    edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = (
        [],
        [],
        [],
        [],
        [],
    )
    total_num_vd = 0
    vd_idx_map = torch.zeros(
        (case_ids.shape[0], 12),
        dtype=torch.long,
        device=self.device,
        requires_grad=False,
    )
    if grad_func is not None:
        # Unit gradients at the zero crossings serve as the QEF plane normals.
        normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
        vd = []
    for num in torch.unique(num_vd):
        cur_cubes = (
            num_vd == num
        )  # consider cubes with the same numbers of vd emitted (for batching)
        curr_num_vd = cur_cubes.sum() * num
        # For each cube, its first `num` dmc_table rows (7 cube-edge ids each,
        # padded with -1), flattened to one row per cube.
        curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(
            -1, num * 7
        )
        # Global dual-vertex id for each of the 7 slots; ids continue after the
        # vertices emitted by earlier batches.
        curr_edge_group_to_vd = (
            torch.arange(curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7)
            + total_num_vd
        )
        total_num_vd += curr_num_vd
        # Surface-cube index for each slot.
        curr_edge_group_to_cube = (
            torch.arange(idx_map.shape[0], device=self.device)[cur_cubes]
            .unsqueeze(-1)
            .repeat(1, num * 7)
            .reshape_as(curr_edge_group)
        )
        # Slots holding a real edge id (not -1 padding).
        curr_mask = curr_edge_group != -1
        edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
        edge_group_to_vd.append(
            torch.masked_select(
                curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask
            )
        )
        edge_group_to_cube.append(
            torch.masked_select(curr_edge_group_to_cube, curr_mask)
        )
        # Number of edges per dual vertex.
        vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
        # Each dual vertex inherits its cube's gamma.
        vd_gamma.append(
            torch.masked_select(gamma_f, cur_cubes)
            .unsqueeze(-1)
            .repeat(1, num)
            .reshape(-1)
        )
        if grad_func is not None:
            # QEF path: gather each group's zero crossings and normals, then solve
            # relative to corner 0 of the cube (v0).
            with torch.no_grad():
                cube_e_verts_idx = idx_map[cur_cubes]
                # Point padded slots at edge 0 so gather stays in range; they are
                # zeroed by curr_mask below. (edge_group already holds a copy.)
                curr_edge_group[~curr_mask] = 0
                verts_group_idx = torch.gather(
                    input=cube_e_verts_idx, dim=1, index=curr_edge_group
                )
                # Non-crossing edges (-1) are redirected to row 0; presumably this
                # only happens for padded slots, which are masked out.
                verts_group_idx[verts_group_idx == -1] = 0
                verts_group_pos = torch.index_select(
                    input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0
                ).reshape(-1, num.item(), 7, 3)
                v0 = (
                    x_nx3[surf_cubes_fx8[cur_cubes][:, 0]]
                    .reshape(-1, 1, 1, 3)
                    .repeat(1, num.item(), 1, 1)
                )
                curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
                # Mean of the valid zero crossings; used as the QEF regularization target.
                verts_centroid = (verts_group_pos * curr_mask).sum(2) / (
                    curr_mask.sum(2)
                )
                normals_bx7x3 = torch.index_select(
                    input=normals, index=verts_group_idx.reshape(-1), dim=0
                ).reshape(-1, num.item(), 7, 3)
                # NOTE(review): dim 2 has size 7 here, so squeeze(2) is a no-op and
                # curr_mask keeps shape (-1, num, 7, 1); the products below broadcast.
                curr_mask = curr_mask.squeeze(2)
                vd.append(
                    self._solve_vd_QEF(
                        (verts_group_pos - v0) * curr_mask,
                        normals_bx7x3 * curr_mask,
                        verts_centroid - v0.squeeze(2),
                    )
                    + v0.reshape(-1, 3)
                )
    edge_group = torch.cat(edge_group)
    edge_group_to_vd = torch.cat(edge_group_to_vd)
    edge_group_to_cube = torch.cat(edge_group_to_cube)
    vd_num_edges = torch.cat(vd_num_edges)
    vd_gamma = torch.cat(vd_gamma)
    if grad_func is not None:
        vd = torch.cat(vd)
        L_dev = torch.zeros([1], device=self.device)
    else:
        # Differentiable path: each dual vertex is the beta-weighted average of the
        # alpha-weighted zero crossings of its edges.
        vd = torch.zeros((total_num_vd, 3), device=self.device)
        beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
        # Index into surf_edges of every (cube, edge) slot in the groups.
        idx_group = torch.gather(
            input=idx_map.reshape(-1),
            dim=0,
            index=edge_group_to_cube * 12 + edge_group,
        )
        x_group = torch.index_select(
            input=surf_edges_x, index=idx_group.reshape(-1), dim=0
        ).reshape(-1, 2, 3)
        s_group = torch.index_select(
            input=surf_edges_s, index=idx_group.reshape(-1), dim=0
        ).reshape(-1, 2, 1)
        zero_crossing_group = torch.index_select(
            input=zero_crossing, index=idx_group.reshape(-1), dim=0
        ).reshape(-1, 3)
        alpha_group = torch.index_select(
            input=alpha_nx12x2.reshape(-1, 2),
            dim=0,
            index=edge_group_to_cube * 12 + edge_group,
        ).reshape(-1, 2, 1)
        ue_group = self._linear_interp(s_group * alpha_group, x_group)
        beta_group = torch.gather(
            input=beta_fx12.reshape(-1),
            dim=0,
            index=edge_group_to_cube * 12 + edge_group,
        ).reshape(-1, 1)
        beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
        vd = (
            vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group)
            / beta_sum
        )
        L_dev = self._compute_reg_loss(
            vd, zero_crossing_group, edge_group_to_vd, vd_num_edges
        )
    # Record, for every (cube, edge) slot used by a group, the dual vertex it belongs to.
    v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd
    vd_idx_map = (vd_idx_map.reshape(-1)).scatter(
        dim=0,
        index=edge_group_to_cube * 12 + edge_group,
        src=v_idx[edge_group_to_vd],
    )
    return vd, L_dev, vd_gamma, vd_idx_map
def _triangulate(
    self,
    s_n,
    surf_edges,
    vd,
    vd_gamma,
    edge_counts,
    idx_map,
    vd_idx_map,
    surf_edges_mask,
    training,
    grad_func,
):
    """
    Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
    triangles based on the gamma parameter, as described in Section 4.3.

    Args:
        s_n: per-grid-vertex scalar field; only the sign of the first endpoint (``> 0``)
            is used here, to choose the quad winding.
        surf_edges: grid-vertex index pairs of the surface edges, indexed by the edge ids
            found in ``idx_map``.
        vd: dual-vertex positions, indexed by the entries of ``vd_idx_map``.
        vd_gamma: per-dual-vertex gamma weights. It is ignored (recomputed from
            ``grad_func(vd)``) when ``grad_func`` is given.
        edge_counts: per-(cube, edge) count; only entries equal to 4 (an edge shared
            by 4 cubes) take part in a quad.
        idx_map: flattened per-(cube, edge) surface-edge id.
        vd_idx_map: flattened per-(cube, edge) dual-vertex index.
        surf_edges_mask: boolean mask over the same flattened (cube, edge) layout.
        training: if False, each quad becomes 2 triangles split along one diagonal.
            If True, each quad becomes 4 triangles around an appended,
            gamma-weighted center vertex.
        grad_func: optional callable on ``vd``. Its normalized output picks the split
            diagonal via gradient agreement across each diagonal.

    Returns:
        ``(vd, faces, s_edges, edge_indices)``:
            - ``vd``: dual vertices, with quad centers appended in training mode.
            - ``faces``: triangle indices of shape ``(-1, 3)``.
            - ``s_edges``: ``s_n`` at both endpoints of each quad's edge, shape ``(-1, 2)``.
            - ``edge_indices``: the sorted edge ids, 4 consecutive entries per quad.
    """
    with torch.no_grad():
        group_mask = (
            edge_counts == 4
        ) & surf_edges_mask  # surface edges shared by 4 cubes.
        group = idx_map.reshape(-1)[group_mask]
        vd_idx = vd_idx_map[group_mask]
        # A stable sort places the 4 entries sharing an edge id next to each other,
        # so every row of the (-1, 4) reshape below is one quad.
        edge_indices, indices = torch.sort(group, stable=True)
        quad_vd_idx = vd_idx[indices].reshape(-1, 4)
        # Ensure all face directions point towards the positive SDF to maintain consistent winding.
        s_edges = s_n[
            surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)
        ].reshape(-1, 2)
        flip_mask = s_edges[:, 0] > 0
        quad_vd_idx = torch.cat(
            (
                quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]],
            )
        )
    if grad_func is not None:
        # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
        with torch.no_grad():
            vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
            quad_gamma = torch.index_select(
                input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0
            ).reshape(-1, 4, 3)
            # Cosine agreement of the gradients at opposite corners of each diagonal.
            gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
            gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
    else:
        # Scalar gamma per corner; diagonal score is the product of its two corners.
        quad_gamma = torch.index_select(
            input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0
        ).reshape(-1, 4)
        gamma_02 = torch.index_select(
            input=quad_gamma, index=torch.tensor(0, device=self.device), dim=1
        ) * torch.index_select(
            input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1
        )
        gamma_13 = torch.index_select(
            input=quad_gamma, index=torch.tensor(1, device=self.device), dim=1
        ) * torch.index_select(
            input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1
        )
    if not training:
        # Two triangles per quad (6 indices); the split pattern depends on which diagonal scores higher.
        mask = (gamma_02 > gamma_13).squeeze(1)
        faces = torch.zeros(
            (quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device
        )
        faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
        faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
        faces = faces.reshape(-1, 3)
    else:
        # Differentiable variant: add a center vertex that blends the two diagonal
        # midpoints by their gamma scores, then fan 4 triangles around it.
        vd_quad = torch.index_select(
            input=vd, index=quad_vd_idx.reshape(-1), dim=0
        ).reshape(-1, 4, 3)
        vd_02 = (
            torch.index_select(
                input=vd_quad, index=torch.tensor(0, device=self.device), dim=1
            )
            + torch.index_select(
                input=vd_quad, index=torch.tensor(2, device=self.device), dim=1
            )
        ) / 2
        vd_13 = (
            torch.index_select(
                input=vd_quad, index=torch.tensor(1, device=self.device), dim=1
            )
            + torch.index_select(
                input=vd_quad, index=torch.tensor(3, device=self.device), dim=1
            )
        ) / 2
        # Small epsilon guards against division by zero when both scores are 0.
        weight_sum = (gamma_02 + gamma_13) + 1e-8
        vd_center = (
            (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1))
            / weight_sum.unsqueeze(-1)
        ).squeeze(1)
        # New center vertices are appended after the existing dual vertices.
        vd_center_idx = (
            torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
        )
        vd = torch.cat([vd, vd_center])
        faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
        faces = torch.cat(
            [faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1
        ).reshape(-1, 3)
    return vd, faces, s_edges, edge_indices
def _tetrahedralize(
    self,
    x_nx3,
    s_n,
    cube_fx8,
    vertices,
    faces,
    surf_edges,
    s_edges,
    vd_idx_map,
    case_ids,
    edge_indices,
    surf_cubes,
    training,
):
    """
    Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.

    Args:
        x_nx3: grid vertex positions.
        s_n: per-grid-vertex scalar field; vertices with ``s_n < 0`` are treated as inside.
        cube_fx8: 8 grid-vertex indices per cube.
        vertices: surface mesh vertices; all newly created vertices are appended after them.
        faces: surface triangles; each is turned into a tetrahedron by adding one inside grid vertex.
        surf_edges, s_edges, edge_indices: surface-edge data as produced by ``_triangulate``.
        vd_idx_map: per-(surface cube, edge) dual-vertex index, viewed as ``(-1, 12)``.
        case_ids: per-surface-cube case id, used to index ``self.tet_table``.
        surf_cubes: boolean mask selecting surface cubes.
        training: must match the mode used in ``_triangulate``. It determines
            whether each quad produced 2 (False) or 4 (True) triangles.

    Returns:
        ``(vertices, tets)``:
            - ``vertices``: the input vertices, followed by the inside grid vertices,
              followed by the centers of fully-inside cubes.
            - ``tets``: tetrahedron vertex indices of shape ``(-1, 4)``.
    """
    occ_n = s_n < 0
    occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
    occ_sum = torch.sum(occ_fx8, -1)
    inside_verts = x_nx3[occ_n]
    # Map each inside grid vertex to its new index in the output vertex list (-1 elsewhere).
    mapping_inside_verts = (
        torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
    )
    mapping_inside_verts[occ_n] = (
        torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
    )
    """
    For each grid edge connecting two grid vertices with different
    signs, we first form a four-sided pyramid by connecting one
    of the grid vertices with four mesh vertices that correspond
    to the grid edge and then subdivide the pyramid into two tetrahedra
    """
    # Pick the negative-s endpoint of each quad's edge (presumably exactly one per
    # sign-changing edge) and repeat it once per triangle that quad produced.
    inside_verts_idx = mapping_inside_verts[
        surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[s_edges < 0]
    ]
    if not training:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
    else:
        inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
    tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
    """
    For each grid edge connecting two grid vertices with the
    same sign, the tetrahedron is formed by the two grid vertices
    and two vertices in consecutive adjacent cells
    """
    inside_cubes = occ_sum == 8
    inside_cubes_center = (
        x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
    )
    # Cube centers are placed after the surface vertices and the inside grid vertices.
    inside_cubes_center_idx = (
        torch.arange(inside_cubes_center.shape[0], device=inside_cubes.device)
        + vertices.shape[0]
        + inside_verts.shape[0]
    )
    surface_n_inside_cubes = surf_cubes | inside_cubes
    # Per selected cube: 12 slots for edge dual vertices plus 1 for the cube center (-1 = none).
    edge_center_vertex_idx = (
        torch.ones(
            ((surface_n_inside_cubes).sum(), 13),
            dtype=torch.long,
            device=x_nx3.device,
        )
        * -1
    )
    surf_cubes = surf_cubes[surface_n_inside_cubes]
    inside_cubes = inside_cubes[surface_n_inside_cubes]
    edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
    edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
    all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
    unique_edges, _idx_map, counts = torch.unique(
        all_edges, dim=0, return_inverse=True, return_counts=True
    )
    unique_edges = unique_edges.long()
    # Edges whose two endpoints are both inside.
    mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
    mask = mask_edges[_idx_map]
    counts = counts[_idx_map]
    mapping = (
        torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device)
        * -1
    )
    mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
    idx_map = mapping[_idx_map]
    # Keep inside edges shared by 4 selected cubes; sorting groups their 4 occurrences.
    group_mask = (counts == 4) & mask
    group = idx_map.reshape(-1)[group_mask]
    edge_indices, indices = torch.sort(group)
    cube_idx = (
        torch.arange(
            (_idx_map.shape[0] // 12), dtype=torch.long, device=self.device
        )
        .unsqueeze(1)
        .expand(-1, 12)
        .reshape(-1)[group_mask]
    )
    edge_idx = (
        torch.arange((12), dtype=torch.long, device=self.device)
        .unsqueeze(0)
        .expand(_idx_map.shape[0] // 12, -1)
        .reshape(-1)[group_mask]
    )
    # Identify the face shared by the adjacent cells.
    cube_idx_4 = cube_idx[indices].reshape(-1, 4)
    edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
    shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
    cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
    # Identify an edge of the face with different signs and
    # select the mesh vertex corresponding to the identified edge.
    # Non-surface cubes get case id 255.
    case_ids_expand = (
        torch.ones(
            (surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device
        )
        * 255
    )
    case_ids_expand[surf_cubes] = case_ids
    cases = case_ids_expand[cube_idx_4x2]
    quad_edge = edge_center_vertex_idx[
        cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]
    ].reshape(-1, 2)
    # Drop tetrahedra that would reference a missing (-1) vertex.
    mask = (quad_edge == -1).sum(-1) == 0
    inside_edge = mapping_inside_verts[
        unique_edges[mask_edges][edge_indices].reshape(-1)
    ].reshape(-1, 2)
    tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
    tets = torch.cat([tets_surface, tets_inside])
    vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
    return vertices, tets
def get_center_boundary_index(grid_res, device):
    """Return flat indices of the grid's chosen center vertex and of its boundary shell.

    The vertex grid has ``grid_res + 1`` points per axis and is flattened in
    row-major order.

    Returns:
        ``(center_indices, boundary_indices)``, each a ``(K, 1)`` long tensor as
        produced by ``torch.nonzero``:
            - ``center_indices``: the single vertex at ``grid_res // 2 + 1`` on every axis.
            - ``boundary_indices``: every vertex lying in the two outermost layers
              along any axis.
    """
    side = grid_res + 1
    mid = grid_res // 2 + 1
    grid = torch.zeros((side, side, side), dtype=torch.bool, device=device)

    grid[mid, mid, mid] = True
    center_indices = torch.nonzero(grid.reshape(-1))
    grid[mid, mid, mid] = False

    # Mark the first two and last two slabs along each axis (narrow returns views).
    for axis in range(3):
        grid.narrow(axis, 0, 2).fill_(True)
        grid.narrow(axis, side - 2, 2).fill_(True)
    boundary_indices = torch.nonzero(grid.reshape(-1))
    return center_indices, boundary_indices
class Geometry:
    """Minimal base class for geometry representations; both hooks are no-ops."""

    def __init__(self):
        """Nothing to initialize in the base class."""

    def forward(self):
        """Placeholder hook; subclasses provide their own behavior."""
class FlexiCubesGeometry(Geometry):
    """FlexiCubes mesh extraction over a fixed, scaled voxel grid, plus rendering helpers.

    Args:
        grid_res: grid resolution passed to ``FlexiCubes.construct_voxel_grid``.
        scale: a single number applied to every grid-vertex coordinate, or a
            3-element list/tuple of per-axis factors ``(x, y, z)``.
        device: device for the FlexiCubes instance and the boundary index tensors.
        renderer: object providing ``render_mesh(...)``; used when ``render_type``
            is ``"neural_render"``.
        render_type: rendering backend selector; only ``"neural_render"`` is implemented.
        args: stored on ``self.args``; not read by this class.
    """

    def __init__(
        self,
        grid_res=64,
        scale=2.0,
        device="cuda",
        renderer=None,
        render_type="neural_render",
        args=None,
    ):
        super(FlexiCubesGeometry, self).__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        self.fc = FlexiCubes(device, weight_scale=0.5)
        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
        if isinstance(scale, (list, tuple)):
            # Per-axis scaling: coordinate i is multiplied by scale[i].
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale[2]
        else:
            self.verts = self.verts * scale
        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
        self.all_edges = torch.unique(all_edges, dim=0)
        # Parameters used for fix boundary sdf
        self.center_indices, self.boundary_indices = get_center_boundary_index(
            self.grid_res, device
        )
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        """Return the ``(min, max)`` corners of the scaled grid's axis-aligned bounding box."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(
        self,
        v_deformed_nx3,
        sdf_n,
        weight_n=None,
        with_uv=False,
        indices=None,
        is_training=False,
    ):
        """Extract a triangle mesh from per-vertex SDF values with FlexiCubes.

        Args:
            v_deformed_nx3: (possibly deformed) grid vertex positions.
            sdf_n: per-vertex SDF values.
            weight_n: optional per-cube FlexiCubes weights; columns ``[:12]`` are
                beta, ``[12:20]`` alpha and ``[20]`` gamma. When ``None``, no
                weights are passed to FlexiCubes.
            with_uv: unused; kept for interface compatibility.
            indices: cube-to-vertex indices; defaults to the grid's own ``self.indices``.
            is_training: forwarded to FlexiCubes as ``training``.

        Returns:
            ``(verts, faces, v_reg_loss)`` as returned by the FlexiCubes call.
        """
        if indices is None:
            indices = self.indices
        if weight_n is None:
            # NOTE(review): assumes FlexiCubes.__call__ accepts None weights and
            # uses its own defaults for them — confirm against its definition.
            beta_fx12 = alpha_fx8 = gamma_f = None
        else:
            beta_fx12 = weight_n[:, :12]
            alpha_fx8 = weight_n[:, 12:20]
            gamma_f = weight_n[:, 20]
        verts, faces, v_reg_loss = self.fc(
            v_deformed_nx3,
            sdf_n,
            indices,
            self.grid_res,
            beta_fx12=beta_fx12,
            alpha_fx8=alpha_fx8,
            gamma_f=gamma_f,
            training=is_training,
        )
        return verts, faces, v_reg_loss

    def render_mesh(
        self,
        mesh_v_nx3,
        mesh_f_fx3,
        camera_mv_bx4x4,
        resolution=256,
        hierarchical_mask=False,
    ):
        """Render one mesh with the configured renderer.

        Returns:
            dict with keys ``tex_pos``, ``mask``, ``hard_mask``, ``rast``,
            ``v_pos_clip``, ``mask_pyramid``, ``depth`` and ``normal``.

        Raises:
            NotImplementedError: if ``render_type`` is not ``"neural_render"``.
        """
        return_value = dict()
        if self.render_type == "neural_render":
            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = (
                self.renderer.render_mesh(
                    mesh_v_nx3.unsqueeze(dim=0),
                    mesh_f_fx3.int(),
                    camera_mv_bx4x4,
                    mesh_v_nx3.unsqueeze(dim=0),  # vertex positions double as the features
                    resolution=resolution,
                    device=self.device,
                    hierarchical_mask=hierarchical_mask,
                )
            )
            return_value["tex_pos"] = tex_pos
            return_value["mask"] = mask
            return_value["hard_mask"] = hard_mask
            return_value["rast"] = rast
            return_value["v_pos_clip"] = v_pos_clip
            return_value["mask_pyramid"] = mask_pyramid
            return_value["depth"] = depth
            return_value["normal"] = normal
        else:
            raise NotImplementedError
        return return_value

    def render(
        self,
        v_deformed_bxnx3=None,
        sdf_bxn=None,
        camera_mv_bxnviewx4x4=None,
        resolution=256,
    ):
        """Extract and render one mesh per batch item.

        The meshes may differ per batch item, so outputs are not concatenated here.

        Returns:
            dict mapping each ``render_mesh`` output key to a list with one entry
            per batch item, in batch order.
        """
        all_render_output = []
        for i_batch in range(v_deformed_bxnx3.shape[0]):
            # get_mesh returns (verts, faces, v_reg_loss); the regularizer is not needed here.
            verts_nx3, faces_fx3, _ = self.get_mesh(
                v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]
            )
            all_render_output.append(
                self.render_mesh(
                    verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution
                )
            )
        # We can do concatenation outside of the render
        return {
            k: [output[k] for output in all_render_output]
            for k in all_render_output[0].keys()
        }
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Interpolate per-vertex attributes over rasterized pixels with nvdiffrast.

    When ``rast_db`` is supplied, image-space derivatives are requested for all
    attributes; otherwise no attribute derivatives are computed.
    """
    want_derivatives = rast_db is not None
    return dr.interpolate(
        attr.contiguous(),
        rast,
        attr_idx,
        rast_db=rast_db,
        diff_attrs="all" if want_derivatives else None,
    )
def xfm_points(points, matrix, use_python=True):
    """Apply 4x4 transforms to 3D points.

    Args:
        points: tensor of shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3].
        matrix: tensor of shape [minibatch_size, 4, 4], applied as ``matrix @ [x, y, z, 1]``.
        use_python: unused; kept for interface compatibility.

    Returns:
        Homogeneous transformed points of shape [minibatch_size, num_vertices, 4].
    """
    # Append w = 1 so the translation column of the matrix takes effect.
    homogeneous = torch.nn.functional.pad(
        points, pad=(0, 1), mode="constant", value=1.0
    )
    transformed = homogeneous @ matrix.transpose(1, 2)
    if torch.is_anomaly_enabled():
        assert torch.all(
            torch.isfinite(transformed)
        ), "Output of xfm_points contains inf or NaN"
    return transformed
def dot(x, y):
    """Dot product along the last dimension, keeping it as a size-1 axis."""
    return (x * y).sum(dim=-1, keepdim=True)
def compute_vertex_normal(v_pos, t_pos_idx):
    """Compute per-vertex normals by accumulating unnormalized face normals.

    Each face normal is ``cross(v1 - v0, v2 - v0)``, so larger faces weigh more.
    It is added to all three of the face's vertices, and each vertex sum is
    then normalized.

    Args:
        v_pos: vertex positions, shape ``(num_vertices, 3)``.
        t_pos_idx: integer triangle indices, shape ``(num_faces, 3)``.

    Returns:
        Unit normals of shape ``(num_vertices, 3)``. A vertex whose accumulated
        normal is (near) zero, e.g. one not referenced by any face, gets ``(0, 0, 1)``.

    Raises:
        AssertionError: if any resulting normal is not finite.
    """
    i0 = t_pos_idx[:, 0]
    i1 = t_pos_idx[:, 1]
    i2 = t_pos_idx[:, 2]
    v0 = v_pos[i0, :]
    v1 = v_pos[i1, :]
    v2 = v_pos[i2, :]
    # dim=-1 must be explicit: without it torch.cross uses the FIRST dimension of
    # size 3, which is the face dimension whenever the mesh has exactly 3 faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    v_nrm.index_add_(0, i0, face_normals)
    v_nrm.index_add_(0, i1, face_normals)
    v_nrm.index_add_(0, i2, face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    sq_len = (v_nrm * v_nrm).sum(dim=-1, keepdim=True)
    v_nrm = torch.where(
        sq_len > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    )
    v_nrm = F.normalize(v_nrm, dim=1)
    assert torch.all(torch.isfinite(v_nrm))
    return v_nrm
class Renderer:
    """Minimal base class for renderers; both hooks are no-ops."""

    def __init__(self):
        """Nothing to initialize in the base class."""

    def forward(self):
        """Placeholder hook; subclasses provide their own behavior."""
class NeuralRender(Renderer):
    """nvdiffrast rasterizer producing per-pixel features, masks, depth and normals.

    Args:
        device: device for the ``dr.RasterizeCudaContext``.
        camera_model: object whose ``project(points_bxnx4)`` maps camera-space
            homogeneous points to clip space.
    """

    def __init__(self, device="cuda", camera_model=None):
        super(NeuralRender, self).__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    def render_mesh(
        self,
        mesh_v_pos_bxnx3,
        mesh_t_pos_idx_fx3,
        camera_mv_bx4x4,
        mesh_v_feat_bxnxd,
        resolution=256,
        spp=1,
        device="cuda",
        hierarchical_mask=False,
    ):
        """Rasterize a mesh from one or more camera poses.

        Args:
            mesh_v_pos_bxnx3: vertex positions with a leading batch axis;
                normals are computed from ``[0]``.
            mesh_t_pos_idx_fx3: triangle indices, passed to nvdiffrast as given.
            camera_mv_bx4x4: model-view matrices (tensor or array-like), one per view.
            mesh_v_feat_bxnxd: per-vertex features, tiled once per view.
            resolution: output size before the ``spp`` multiplier.
            spp: multiplier applied to ``resolution`` for the raster size.
            device: device used when ``camera_mv_bx4x4`` must be converted to a tensor.
            hierarchical_mask: must be False.

        Returns:
            Tuple ``(ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
            mask_pyramid, depth, normal)``:
                - ``mask_pyramid`` is always ``None``.
                - ``depth`` is the interpolated camera-space z.
                - ``normal`` is mapped to ``[0, 1]`` on a black background.
        """
        assert not hierarchical_mask
        if torch.is_tensor(camera_mv_bx4x4):
            mtx_in = camera_mv_bx4x4
        else:
            mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)

        # Rotate into camera coordinates, then project into clip space.
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)
        v_pos_clip = self.camera.project(v_pos)
        # Vertex normals in world coordinates, from the first mesh of the batch.
        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())

        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes

        # Attributes to interpolate: the caller's features (tiled per view) followed
        # by the 4 homogeneous camera-space coordinates.
        n_views = v_pos.shape[0]
        attrs = torch.cat([mesh_v_feat_bxnxd.repeat(n_views, 1, 1), v_pos], dim=-1)

        raster_size = [resolution * spp, resolution * spp]
        with dr.DepthPeeler(
            self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, raster_size
        ) as peeler:
            # Only a single (nearest) depth layer is rasterized.
            rast, _ = peeler.rasterize_next_layer()
            gb_feat, _ = interpolate(attrs, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3
        )
        depth = gb_feat[..., -2:-1]  # z of the appended camera-space coordinates
        ori_mesh_feature = gb_feat[..., :-4]  # strip the 4 appended coordinates

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(
            normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3
        )
        normal = F.normalize(normal, dim=-1)
        # Map [-1, 1] -> [0, 1] where covered; uncovered pixels stay 0 (black background).
        normal = torch.lerp(
            torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()
        )
        return (
            ori_mesh_feature,
            antialias_mask,
            hard_mask,
            rast,
            v_pos_clip,
            mask_pyramid,
            depth,
            normal,
        )
def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
    """Build a perspective projection matrix as a float32 numpy array.

    Args:
        x: image-plane half-extent (``PerspectiveCamera`` passes ``tan(fovy / 2)``).
        n: near distance used for the focal terms ``n / x``; the y axis is negated.
        f: far plane distance.
        near_plane: near distance used in the depth terms; defaults to ``n``.

    Returns:
        A ``(4, 4)`` float32 matrix.
    """
    near = n if near_plane is None else near_plane
    depth_span = f - near
    rows = [
        [n / x, 0, 0, 0],
        [0, n / -x, 0, 0],
        [0, 0, -(f + near) / depth_span, -(2 * f * near) / depth_span],
        [0, 0, -1, 0],
    ]
    return np.array(rows).astype(np.float32)
class Camera(nn.Module):
    """Base class for camera models, registered as an ``nn.Module``."""

    def __init__(self):
        super().__init__()
class PerspectiveCamera(Camera):
    """Camera that applies a fixed perspective projection to homogeneous points.

    Args:
        fovy: vertical field of view in degrees.
        device: device on which the projection matrix is stored.
    """

    def __init__(self, fovy=49.0, device="cuda"):
        super(PerspectiveCamera, self).__init__()
        self.device = device
        # Tangent of half the vertical field of view (degrees -> radians).
        half_fov_tan = np.tan(fovy / 180.0 * np.pi * 0.5)
        proj = projection(x=half_fov_tan, f=1000.0, n=1.0, near_plane=0.1)
        # Stored as a plain tensor attribute of shape (1, 4, 4), not a registered buffer.
        self.proj_mtx = torch.from_numpy(proj).to(self.device).unsqueeze(dim=0)

    def project(self, points_bxnx4):
        """Map homogeneous camera-space points ``(..., 4)`` to clip space."""
        return points_bxnx4 @ self.proj_mtx.transpose(1, 2)
class ViTEmbeddings(nn.Module):
    """Build the encoder's input tokens: [CLS] + patch embeddings, plus position embeddings.

    When ``use_mask_token`` is set, patches flagged in ``bool_masked_pos`` are
    replaced by a learned mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden))
        self.mask_token = (
            nn.Parameter(torch.zeros(1, 1, hidden)) if use_mask_token else None
        )
        self.patch_embeddings = ViTPatchEmbeddings(config)
        # One learned position per patch plus one for the [CLS] token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, hidden)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resize the pre-trained patch position embeddings (bicubic) so the model can run on
        images whose patch grid differs from the pre-training one.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # The +0.1 offset avoids floating point error in the interpolation,
        # see https://github.com/facebookresearch/dino/issues/8
        h0 = height // self.config.patch_size + 0.1
        w0 = width // self.config.patch_size + 0.1
        src_side = math.sqrt(num_positions)
        grid_side = int(src_side)

        patch_pos_embed = patch_pos_embed.reshape(1, grid_side, grid_side, dim).permute(
            0, 3, 1, 2
        )
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / src_side, w0 / src_side),
            mode="bicubic",
            align_corners=False,
        )
        assert (
            int(h0) == patch_pos_embed.shape[-2]
            and int(w0) == patch_pos_embed.shape[-1]
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        if bool_masked_pos is not None:
            # Swap masked patch tokens for the learned mask token.
            mask_tokens = self.mask_token.expand(batch_size, embeddings.shape[1], -1)
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # Prepend the [CLS] token.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        if interpolate_pos_encoding:
            pos_embed = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            pos_embed = self.position_embeddings
        return self.dropout(embeddings + pos_embed)
class ViTPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch
    embeddings of shape `(batch_size, seq_length, hidden_size)` with one strided
    convolution (kernel = stride = patch size).
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Scalars become (value, value); iterables are kept as-is.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
    ) -> torch.Tensor:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        expected_h, expected_w = self.image_size[0], self.image_size[1]
        # The input size may only differ from the configured one when positions get interpolated.
        if not interpolate_pos_encoding and (height != expected_h or width != expected_w):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({expected_h}*{expected_w})."
            )
        feature_map = self.projection(pixel_values)  # (B, hidden, H/p, W/p)
        return feature_map.flatten(2).transpose(1, 2)
class ViTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (query/key/value projections only).

    The output projection is applied separately by ``ViTSelfOutput``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``
            and the config has no ``embedding_size``.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, L, heads * head_dim)`` to ``(B, heads, L, head_dim)``."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``(context,)``, or ``(context, attention_probs)`` when ``output_attentions``."""
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads back: (B, heads, L, head_dim) -> (B, L, heads * head_dim).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )
        return outputs
class ViTSelfOutput(nn.Module):
    """
    Output projection (dense + dropout) of the attention block. ``input_tensor`` is
    accepted but unused: the residual connection is added in ViTLayer instead, because
    layernorm is applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
class ViTAttention(nn.Module):
    """Self-attention followed by its output projection, with support for head pruning."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the q/k/v and output projections."""
        if len(heads) == 0:
            return
        attn = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )
        # Prune linear layers: q/k/v on their output dim, the output dense on its input dim.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params and store pruned heads
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Return ``(attention_output,)`` plus the attention probabilities when requested."""
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output, *self_outputs[1:])
class ViTIntermediate(nn.Module):
    """First MLP layer: project to ``intermediate_size`` and apply the activation.

    ``config.hidden_act`` may be an ``ACT2FN`` key (string) or a callable.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class ViTOutput(nn.Module):
    """Second MLP layer: project back to ``hidden_size``, apply dropout, add the residual."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection back to the block input.
        return projected + input_tensor
def modulate(x, shift, scale):
    """Apply per-sample affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis at dim 1, so they broadcast over
    the token dimension of ``x``.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
class ViTLayer(nn.Module):
    """Pre-norm transformer block with adaptive-LayerNorm modulation.

    This corresponds to the Block class in the timm implementation. ``adaln_input``
    is mapped to shift/scale pairs that modulate the normalized inputs of both the
    attention and MLP sub-blocks.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        hidden, eps = config.hidden_size, config.layer_norm_eps
        self.layernorm_before = nn.LayerNorm(hidden, eps=eps)
        self.layernorm_after = nn.LayerNorm(hidden, eps=eps)
        modulation_proj = nn.Linear(hidden, 4 * hidden, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), modulation_proj)
        # Zero-init so shift = scale = 0 and modulate() initially passes inputs through.
        # NOTE(review): a later weight-init pass over nn.Linear modules may overwrite this.
        nn.init.constant_(modulation_proj.weight, 0)
        nn.init.constant_(modulation_proj.bias, 0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        modulation = self.adaLN_modulation(adaln_input)
        shift_msa, scale_msa, shift_mlp, scale_mlp = modulation.chunk(4, dim=1)

        # Attention on the modulated, layer-normed input; then the first residual.
        attn_input = modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa)
        attn_outputs = self.attention(
            attn_input, head_mask, output_attentions=output_attentions
        )
        hidden_states = attn_outputs[0] + hidden_states

        # MLP on the modulated, layer-normed input; ViTOutput adds the second residual.
        mlp_input = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
        layer_output = self.output(self.intermediate(mlp_input), hidden_states)

        # Attention weights (if requested) follow the layer output.
        return (layer_output, *attn_outputs[1:])
class ViTEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ViTLayer blocks sharing one ``adaln_input``."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [ViTLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run all layers.

        When requested, collects the hidden state before every layer plus the final
        one, and each layer's attention weights.
        """
        hidden_history = () if output_hidden_states else None
        attention_history = () if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                hidden_history += (hidden_states,)
            block_head_mask = None if head_mask is None else head_mask[idx]
            if use_checkpointing:
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    adaln_input,
                    block_head_mask,
                    output_attentions,
                )
            else:
                block_outputs = block(
                    hidden_states, adaln_input, block_head_mask, output_attentions
                )
            hidden_states = block_outputs[0]
            if output_attentions:
                attention_history += (block_outputs[1],)

        if output_hidden_states:
            hidden_history += (hidden_states,)

        if not return_dict:
            candidates = (hidden_states, hidden_history, attention_history)
            return tuple(item for item in candidates if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=hidden_history,
            attentions=attention_history,
        )
class ViTPreTrainedModel(PreTrainedModel):
    """
    Abstract base handling weight initialization and the pretrained download/load interface.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights of one module.

        - Linear/Conv2d: truncated-normal weights, zero bias.
        - LayerNorm: weight 1, bias 0.
        - ViTEmbeddings: truncated-normal position embeddings and [CLS] token.
        """

        def _trunc_normal(tensor):
            # Sample in fp32, then cast back to the tensor's dtype, to avoid
            # `trunc_normal_cpu` not being implemented for `half`.
            return nn.init.trunc_normal_(
                tensor.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(tensor.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data = _trunc_normal(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = _trunc_normal(
                module.position_embeddings.data
            )
            module.cls_token.data = _trunc_normal(module.cls_token.data)
class ViTModel(ViTPreTrainedModel):
    """ViT backbone whose encoder additionally receives an ``adaln_input`` tensor.

    Layout: ``embeddings`` -> ``encoder`` -> final ``layernorm``, followed by an
    optional first-token ``pooler`` (absent when ``add_pooling_layer=False``).
    """

    def __init__(
        self,
        config: ViTConfig,
        add_pooling_layer: bool = True,
        use_mask_token: bool = False,
    ):
        super().__init__(config)
        self.config = config
        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        if add_pooling_layer:
            self.pooler = ViTPooler(config)
        else:
            self.pooler = None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        """Return the patch-embedding submodule."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index in heads_to_prune:
            attention = self.encoder.layer[layer_index].attention
            attention.prune_heads(heads_to_prune[layer_index])

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        adaln_input: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""Encode ``pixel_values``, passing ``adaln_input`` through to the encoder.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Flags left as ``None`` fall back to the corresponding ``self.config`` values.
        Returns a ``BaseModelOutputWithPooling`` or, when ``return_dict`` is false, a
        tuple ``(sequence_output[, pooled_output], *encoder_extras)``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 1.0 in head_mask means "keep this head"; get_head_mask broadcasts it
        # to one entry per hidden layer.
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        # Match the input dtype to the patch projection weights.
        projection_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != projection_dtype:
            pixel_values = pixel_values.to(projection_dtype)

        tokens = self.embeddings(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoded = self.encoder(
            tokens,
            adaln_input=adaln_input,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        normed = self.layernorm(encoded[0])
        pooled = None if self.pooler is None else self.pooler(normed)

        if not return_dict:
            if pooled is None:
                leading = (normed,)
            else:
                leading = (normed, pooled)
            return leading + encoded[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=normed,
            pooler_output=pooled,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
class ViTPooler(nn.Module):
    """Pool a token sequence by passing its first token through Linear + tanh."""

    def __init__(self, config: ViTConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only the token at index 0 of dimension 1 contributes to the pooled output.
        return self.activation(self.dense(hidden_states[:, 0]))
class DinoWrapper(nn.Module):
    """DINO ViT encoder plus a small MLP that embeds per-view camera vectors.

    The camera embedding is fed to the ViT as ``adaln_input``; the wrapper
    returns the encoder's ``last_hidden_state``.
    """

    def __init__(self, model_name: str, freeze: bool = True):
        """
        :param model_name: Hugging Face model id/path for both model and processor
        :param freeze: put the ViT in eval mode and disable its gradients
            (the camera embedder stays trainable)
        """
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        hidden_size = self.model.config.hidden_size
        # 16 input features per camera -- presumably a flattened 4x4 matrix or
        # extrinsics+intrinsics; TODO confirm against the data pipeline.
        self.camera_embedder = nn.Sequential(
            nn.Linear(16, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        if freeze:
            self._freeze()

    def forward(self, image, camera):
        """Encode images conditioned on their cameras.

        :param image: image batch; a 5-D input has its first two dims merged
        :param camera: camera features whose last dim is 16
        :return: the ViT ``last_hidden_state``
        """
        if image.ndim == 5:
            # Merge the two leading dims (e.g. batch x views); reshape also
            # copes with non-contiguous inputs.
            image = image.reshape(-1, *image.shape[2:])
        dtype = image.dtype
        # Rescaling and resizing are disabled, so the processor is only used
        # for its remaining preprocessing (presumably normalization -- confirm).
        inputs = (
            self.processor(
                images=image.float(),
                return_tensors="pt",
                do_rescale=False,
                do_resize=False,
            )
            .to(self.model.device)
            .to(dtype)
        )
        # embed camera: one embedding row per image
        camera_embeddings = self.camera_embedder(camera)
        camera_embeddings = camera_embeddings.reshape(-1, camera_embeddings.shape[-1])
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.model(
            **inputs, adaln_input=camera_embeddings, interpolate_pos_encoding=True
        )
        return outputs.last_hidden_state

    def _freeze(self):
        """Switch the ViT to eval mode and stop gradients for all its parameters."""
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(
        model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5
    ):
        """Load the ViT model and image processor, retrying on proxy errors.

        :param model_name: Hugging Face model id/path
        :param proxy_error_retries: extra attempts after the first one fails
            with ``requests.exceptions.ProxyError``
        :param proxy_error_cooldown: seconds to sleep between attempts
        :return: ``(model, processor)``
        :raises requests.exceptions.ProxyError: once the retries are exhausted
        """
        import time

        import requests

        retries_left = proxy_error_retries
        while True:
            try:
                model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
                processor = ViTImageProcessor.from_pretrained(model_name)
                return model, processor
            except requests.exceptions.ProxyError:
                if retries_left <= 0:
                    # Bare raise keeps the original traceback.
                    raise
                print(
                    f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds..."
                )
                time.sleep(proxy_error_cooldown)
                retries_left -= 1
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: cross-attention, self-attention, then an MLP.

    Each sub-layer is a residual update applied to a LayerNorm-ed input.
    Note: ``eps`` is accepted for interface compatibility, but the three
    LayerNorms are built with PyTorch's default epsilon, not with ``eps``.
    """

    def __init__(
        self,
        inner_dim: int,
        cond_dim: int,
        num_heads: int,
        eps: float,
        attn_drop: float = 0.0,
        attn_bias: bool = False,
        mlp_ratio: float = 4.0,
        mlp_drop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = int(inner_dim * mlp_ratio)

        # Queries come from the tokens; keys/values come from the condition.
        self.norm1 = nn.LayerNorm(inner_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            kdim=cond_dim,
            vdim=cond_dim,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )

        self.norm2 = nn.LayerNorm(inner_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            bias=attn_bias,
            batch_first=True,
        )

        self.norm3 = nn.LayerNorm(inner_dim)
        # Keep the Sequential layout: state-dict keys are mlp.0.* and mlp.3.*.
        self.mlp = nn.Sequential(
            nn.Linear(inner_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(hidden_dim, inner_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x, cond):
        """Apply the block to tokens ``x`` (batch-first) conditioned on ``cond``."""
        cross_out, _ = self.cross_attn(self.norm1(x), cond, cond)
        x = x + cross_out
        normed = self.norm2(x)
        self_out, _ = self.self_attn(normed, normed, normed)
        x = x + self_out
        return x + self.mlp(self.norm3(x))
class TriplaneTransformer(nn.Module):
    """Turn image features into three feature planes.

    ``3 * triplane_low_res**2`` learned tokens cross-attend to the image features
    through ``num_layers`` BasicTransformerBlocks. The result is reshaped into
    three ``low_res x low_res`` planes and upsampled 2x by a stride-2 transposed
    convolution to ``triplane_dim`` channels. Output shape:
    ``(N, 3, triplane_dim, 2 * low_res, 2 * low_res)``.
    ``triplane_high_res`` is stored but not used in this class.
    """

    def __init__(
        self,
        inner_dim: int,
        image_feat_dim: int,
        triplane_low_res: int,
        triplane_high_res: int,
        triplane_dim: int,
        num_layers: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim

        token_count = 3 * triplane_low_res**2
        init_scale = (1.0 / inner_dim) ** 0.5
        self.pos_embed = nn.Parameter(
            torch.randn(1, token_count, inner_dim) * init_scale
        )

        blocks = [
            BasicTransformerBlock(
                inner_dim=inner_dim,
                cond_dim=image_feat_dim,
                num_heads=num_heads,
                eps=eps,
            )
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(inner_dim, eps=eps)
        # kernel 2 / stride 2 / padding 0 doubles each spatial dimension.
        self.deconv = nn.ConvTranspose2d(
            inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0
        )

    def forward(self, image_feats):
        """:param image_feats: condition features; dim 0 is the batch size N."""
        batch = image_feats.shape[0]
        res = self.triplane_low_res

        tokens = self.pos_embed.repeat(batch, 1, 1)
        for block in self.layers:
            tokens = block(tokens, image_feats)
        tokens = self.norm(tokens)

        # (N, 3*H*W, D) -> (N, 3, H, W, D) -> (3, N, D, H, W)
        planes = tokens.view(batch, 3, res, res, -1).permute(1, 0, 4, 2, 3)
        planes = planes.contiguous().view(3 * batch, -1, res, res)
        planes = self.deconv(planes)

        # (3*N, C, H', W') -> (3, N, C, H', W') -> (N, 3, C, H', W')
        planes = planes.view(3, batch, *planes.shape[-3:]).permute(1, 0, 2, 3, 4)
        return planes.contiguous()
def interpolate_atlas(attr, rast, attr_idx, rast_db=None):
    """Interpolate per-vertex ``attr`` over a rasterization with nvdiffrast.

    When ``rast_db`` is provided, image-space derivatives are requested for all
    attributes (``diff_attrs="all"``); otherwise none are.
    """
    diff_attrs = "all" if rast_db is not None else None
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs
    )
def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
    """UV-unwrap a mesh with xatlas and rasterize its 3D positions into UV space.

    :param ctx: nvdiffrast rasterization context
    :param mesh_v: vertex positions (torch tensor)
    :param mesh_pos_idx: face vertex indices (torch tensor)
    :param resolution: side length of the square UV raster
    :return: ``(uvs, mesh_tex_idx, gb_pos, mask)`` -- UV coordinates, UV face
        indices (int64), interpolated positions per texel, and a boolean coverage
        mask taken from the rasterizer's last channel (``> 0``).
    """
    device = mesh_v.device
    positions_np = mesh_v.detach().cpu().numpy()
    faces_np = mesh_pos_idx.detach().cpu().numpy()
    _, indices, uvs = xatlas.parametrize(positions_np, faces_np)

    # Widen the unsigned indices to 64 bits, then reinterpret them as int64 for torch.
    signed_indices = indices.astype(np.uint64, casting="same_kind").view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device=device)
    mesh_tex_idx = torch.tensor(signed_indices, dtype=torch.int64, device=device)

    # Map UVs to [-1, 1] and append z = 0, w = 1 to form clip-space positions.
    uv_clip = uvs.unsqueeze(0) * 2.0 - 1.0
    z_column = torch.zeros_like(uv_clip[..., 0:1])
    w_column = torch.ones_like(uv_clip[..., 0:1])
    uv_clip4 = torch.cat((uv_clip, z_column, w_column), dim=-1)

    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
    gb_pos, _ = interpolate_atlas(mesh_v.unsqueeze(0), rast, mesh_pos_idx.int())
    mask = rast[..., 3:4] > 0
    return uvs, mesh_tex_idx, gb_pos, mask
class LRM(ModelMixin, ConfigMixin):
    """Large reconstruction model: images + cameras -> triplanes -> FlexiCubes mesh.

    ``encoder`` (DINO + camera conditioning) produces image tokens,
    ``transformer`` turns them into triplanes, and ``synthesizer`` answers
    geometry / texture queries on those triplanes. All geometry and rendering
    methods use ``self.geometry``, which is only created by
    :meth:`init_flexicubes_geometry`; call it first.
    """

    def __init__(
        self,
        encoder_freeze: bool = False,
        encoder_model_name: str = "facebook/dino-vitb16",
        encoder_feat_dim: int = 768,
        transformer_dim: int = 1024,
        transformer_layers: int = 16,
        transformer_heads: int = 16,
        triplane_low_res: int = 32,
        triplane_high_res: int = 64,
        triplane_dim: int = 80,
        rendering_samples_per_ray: int = 128,
        grid_res: int = 128,
        grid_scale: float = 2.1,
    ):
        super().__init__()
        self.grid_res = grid_res
        self.grid_scale = grid_scale
        # Scales the tanh-bounded vertex offsets in get_sdf_deformation_prediction.
        self.deformation_multiplier = 4.0
        self.encoder = DinoWrapper(
            model_name=encoder_model_name,
            freeze=encoder_freeze,
        )
        self.transformer = TriplaneTransformer(
            inner_dim=transformer_dim,
            num_layers=transformer_layers,
            num_heads=transformer_heads,
            image_feat_dim=encoder_feat_dim,
            triplane_low_res=triplane_low_res,
            triplane_high_res=triplane_high_res,
            triplane_dim=triplane_dim,
        )
        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=triplane_dim,
            samples_per_ray=rendering_samples_per_ray,
        )

    def init_flexicubes_geometry(self, device, fovy=50.0):
        """Create ``self.geometry`` (FlexiCubes grid + neural renderer) on ``device``."""
        camera = PerspectiveCamera(fovy=fovy, device=device)
        renderer = NeuralRender(device, camera_model=camera)
        self.geometry = FlexiCubesGeometry(
            grid_res=self.grid_res,
            scale=self.grid_scale,
            renderer=renderer,
            render_type="neural_render",
            device=device,
        )

    def forward_planes(self, images, cameras):
        """Encode ``images``/``cameras`` and decode them into triplanes.

        Image tokens of all views of one batch item are concatenated along the
        token dimension before the triplane transformer.
        """
        B = images.shape[0]
        image_feats = self.encoder(images, cameras)
        image_feats = image_feats.view(B, -1, image_feats.shape[-1])
        planes = self.transformer(image_feats)
        return planes

    def get_sdf_deformation_prediction(self, planes):
        """Predict grid SDF, vertex deformation and FlexiCubes weights.

        Deformations are bounded to ``+-1 / (grid_res * deformation_multiplier)``
        via ``tanh``. For batch items whose interior SDF has no sign change
        (``zero_surface``), the SDF at ``geometry.center_indices`` and
        ``geometry.boundary_indices`` is overwritten with values derived from the
        batch-wide min/max SDF. The SDF and deformation of those items are
        detached, and only those items get a non-zero ``sdf_reg_loss``.

        :return: ``(sdf, deformation, sdf_reg_loss, weight)``
        """
        init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1)
        sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
            self.synthesizer.get_geometry_prediction,
            planes,
            init_position,
            self.geometry.indices,
            use_reentrant=False,
        )
        deformation = (
            1.0
            / (self.grid_res * self.deformation_multiplier)
            * torch.tanh(deformation)
        )
        sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)

        # Look for items whose interior grid is entirely positive or entirely negative.
        sdf_bxnxnxn = sdf.reshape(
            (sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)
        )
        sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
        pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
        neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
        zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)

        if torch.sum(zero_surface).item() > 0:
            update_sdf = torch.zeros_like(sdf[0:1])
            max_sdf = sdf.max()
            min_sdf = sdf.min()
            update_sdf[:, self.geometry.center_indices] += 1.0 - min_sdf
            update_sdf[:, self.geometry.boundary_indices] += -1 - max_sdf
            new_sdf = torch.zeros_like(sdf)
            for i_batch in range(zero_surface.shape[0]):
                if zero_surface[i_batch]:
                    new_sdf[i_batch : i_batch + 1] += update_sdf
            # Keep the predicted SDF wherever no override was written.
            update_mask = (new_sdf == 0).float()
            sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
            sdf = sdf * update_mask + new_sdf * (1 - update_mask)

        # Stop gradients through items that had to be overridden.
        final_sdf = []
        final_def = []
        for i_batch in range(zero_surface.shape[0]):
            if zero_surface[i_batch]:
                final_sdf.append(sdf[i_batch : i_batch + 1].detach())
                final_def.append(deformation[i_batch : i_batch + 1].detach())
            else:
                final_sdf.append(sdf[i_batch : i_batch + 1])
                final_def.append(deformation[i_batch : i_batch + 1])
        sdf = torch.cat(final_sdf, dim=0)
        deformation = torch.cat(final_def, dim=0)
        return sdf, deformation, sdf_reg_loss, weight

    def get_geometry_prediction(self, planes=None):
        """Extract one mesh per batch item from the predicted FlexiCubes grid.

        :return: ``(v_list, f_list, sdf, deformation, v_deformed,
            (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg))``
        """
        sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(
            planes
        )
        v_deformed = (
            self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1)
            + deformation
        )
        tets = self.geometry.indices
        n_batch = planes.shape[0]
        v_list = []
        f_list = []
        flexicubes_surface_reg_list = []
        for i_batch in range(n_batch):
            verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
                v_deformed[i_batch],
                sdf[i_batch].squeeze(dim=-1),
                with_uv=False,
                indices=tets,
                weight_n=weight[i_batch].squeeze(dim=-1),
                is_training=self.training,
            )
            flexicubes_surface_reg_list.append(flexicubes_surface_reg)
            v_list.append(verts)
            f_list.append(faces)
        flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
        flexicubes_weight_reg = (weight**2).mean()
        return (
            v_list,
            f_list,
            sdf,
            deformation,
            v_deformed,
            (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg),
        )

    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
        """Query texture features from the triplanes at 3D positions.

        :param planes: triplane features; ``planes.shape[0]`` is the batch size
        :param tex_pos: list of position tensors (last dim 3), concatenated on dim 0
        :param hard_mask: optional coverage mask indexed as ``(B, H, W, ...)``.
            When given, only covered positions (> 0.5) are queried, and the
            features are scattered back to a zero-initialized ``(B, H, W, C)``
            tensor.
        :return: ``(B, H, W, C)`` features with a mask, otherwise the raw
            per-position features from the synthesizer
        """
        tex_pos = torch.cat(tex_pos, dim=0)
        if hard_mask is not None:
            tex_pos = tex_pos * hard_mask.float()
        batch_size = tex_pos.shape[0]
        tex_pos = tex_pos.reshape(batch_size, -1, 3)

        if hard_mask is not None:
            # A single boolean mask drives both the point counts and the
            # selection, so the two can never disagree.
            keep = hard_mask.reshape(batch_size, -1) > 0.5
            n_points = keep.sum(dim=-1).tolist()
            max_point = max(n_points)
            expanded_hard_mask = keep.unsqueeze(-1).expand(-1, -1, 3)
            # Gather covered positions per item and zero-pad to a common length
            # so the synthesizer receives one dense batch.
            sample_tex_pose_list = []
            for i in range(batch_size):
                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
                n_pad = max_point - tex_pos_one_shape.shape[1]
                if n_pad > 0:
                    tex_pos_one_shape = torch.cat(
                        [tex_pos_one_shape, tex_pos_one_shape.new_zeros(1, n_pad, 3)],
                        dim=1,
                    )
                sample_tex_pose_list.append(tex_pos_one_shape)
            tex_pos = torch.cat(sample_tex_pose_list, dim=0)

        tex_feat = torch.utils.checkpoint.checkpoint(
            self.synthesizer.get_texture_prediction,
            planes,
            tex_pos,
            use_reentrant=False,
        )

        if hard_mask is None:
            # Nothing to scatter back (previously this path crashed on hard_mask.shape).
            return tex_feat

        height, width = hard_mask.shape[1], hard_mask.shape[2]
        n_channels = tex_feat.shape[-1]
        final_tex_feat = torch.zeros(
            planes.shape[0],
            height * width,
            n_channels,
            device=tex_feat.device,
        )
        scatter_mask = keep.unsqueeze(-1).expand(-1, -1, n_channels)
        for i in range(planes.shape[0]):
            # Drop the padding rows, then write features to the covered texels.
            final_tex_feat[i][scatter_mask[i]] = tex_feat[i][: n_points[i]].reshape(-1)
        return final_tex_feat.reshape(planes.shape[0], height, width, n_channels)

    def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
        """Render each mesh with its cameras via ``self.geometry.render_mesh``.

        :return: ``(mask, hard_mask, tex_pos, depth, normal)``. All are
            concatenated over meshes on dim 0, except ``tex_pos``, which stays a
            per-mesh list.
        """
        return_value_list = []
        for i_mesh in range(len(mesh_v)):
            return_value = self.geometry.render_mesh(
                mesh_v[i_mesh],
                mesh_f[i_mesh].int(),
                cam_mv[i_mesh],
                resolution=render_size,
                hierarchical_mask=False,
            )
            return_value_list.append(return_value)
        # Transpose list-of-dicts into dict-of-lists.
        return_value = {
            k: [v[k] for v in return_value_list] for k in return_value_list[0].keys()
        }
        mask = torch.cat(return_value["mask"], dim=0)
        hard_mask = torch.cat(return_value["hard_mask"], dim=0)
        tex_pos = return_value["tex_pos"]
        depth = torch.cat(return_value["depth"], dim=0)
        normal = torch.cat(return_value["normal"], dim=0)
        return mask, hard_mask, tex_pos, depth, normal

    def forward_geometry(self, planes, render_cameras, render_size=256):
        """Extract meshes from ``planes`` and render them from ``render_cameras``.

        :param render_cameras: tensor whose first two dims are (B, NV)
        :return: dict with ``img``, ``mask``, ``depth``, ``normal``
            (each reshaped to ``(B, NV, ...)`` channel-first), plus ``sdf``,
            ``mesh_v``, ``mesh_f`` and the regularization tuple ``sdf_reg_loss``
        """
        B, NV = render_cameras.shape[:2]
        mesh_v, mesh_f, sdf, _, _, sdf_reg_loss = self.get_geometry_prediction(planes)
        cam_mv = render_cameras
        run_n_view = cam_mv.shape[1]
        antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(
            mesh_v, mesh_f, cam_mv, render_size=render_size
        )
        tex_hard_mask = hard_mask
        # Tile the views of each item side by side along dim 2 so a single
        # texture query covers all views.
        tex_pos = [
            torch.cat([pos[i_view : i_view + 1] for i_view in range(run_n_view)], dim=2)
            for pos in tex_pos
        ]
        tex_hard_mask = torch.cat(
            [
                torch.cat(
                    [
                        tex_hard_mask[
                            i * run_n_view + i_view : i * run_n_view + i_view + 1
                        ]
                        for i_view in range(run_n_view)
                    ],
                    dim=2,
                )
                for i in range(planes.shape[0])
            ],
            dim=0,
        )
        tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
        # Uncovered texels get a white (ones) background.
        background_feature = torch.ones_like(tex_feat)
        img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
        # Split the tiled views back out along the batch dimension.
        img_feat = torch.cat(
            [
                torch.cat(
                    [
                        img_feat[
                            i : i + 1,
                            :,
                            render_size * i_view : render_size * (i_view + 1),
                        ]
                        for i_view in range(run_n_view)
                    ],
                    dim=0,
                )
                for i in range(len(tex_pos))
            ],
            dim=0,
        )
        img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
        antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
        depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV))
        normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
        out = {
            "img": img,
            "mask": antilias_mask,
            "depth": depth,
            "normal": normal,
            "sdf": sdf,
            "mesh_v": mesh_v,
            "mesh_f": mesh_f,
            "sdf_reg_loss": sdf_reg_loss,
        }
        return out

    def forward(self, images, cameras, render_cameras, render_size: int):
        """Predict triplanes and render them; returns ``{"planes": ..., **forward_geometry(...)}``."""
        planes = self.forward_planes(images, cameras)
        out = self.forward_geometry(planes, render_cameras, render_size=render_size)
        return {"planes": planes, **out}

    def extract_mesh(
        self,
        planes: torch.Tensor,
        use_texture_map: bool = False,
        texture_resolution: int = 1024,
        progress_callback: Optional[Callable[[float], None]] = None,
        **kwargs,
    ):
        """
        Extract a 3D mesh from FlexiCubes. Only support batch_size 1.
        :param planes: triplane features
        :param use_texture_map: use texture map or vertex color
        :param texture_resolution: the resolution of texture map
        :param progress_callback: optional callable receiving progress fractions:
            0.0 at start, 0.5 after geometry extraction, 1.0 when finished
        :param kwargs: accepted and ignored
        :return: ``(vertices, faces, vertices_colors)`` (uint8 colors) when
            ``use_texture_map`` is False, otherwise
            ``(vertices, faces, uvs, mesh_tex_idx, texture_map)``
        """
        assert planes.shape[0] == 1

        def _report(fraction: float) -> None:
            # Progress notification is optional.
            if progress_callback is not None:
                progress_callback(fraction)

        _report(0.0)
        mesh_v, mesh_f, _, _, _, _ = self.get_geometry_prediction(planes)
        vertices, faces = mesh_v[0], mesh_f[0]
        _report(0.5)

        if not use_texture_map:
            # Per-vertex colors queried directly at the vertex positions.
            vertices_tensor = vertices.unsqueeze(0)
            vertices_colors = (
                self.synthesizer.get_texture_prediction(planes, vertices_tensor)
                .clamp(0, 1)
                .squeeze(0)
                .cpu()
                .numpy()
            )
            vertices_colors = (vertices_colors * 255).astype(np.uint8)
            _report(1.0)
            return vertices, faces, vertices_colors

        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
            self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution
        )
        tex_hard_mask = tex_hard_mask.float()
        tex_feat = self.get_texture_prediction(planes, [gb_pos], tex_hard_mask)
        # Uncovered texels are black (zeros).
        background_feature = torch.zeros_like(tex_feat)
        img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
        texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
        _report(1.0)
        return vertices, faces, uvs, mesh_tex_idx, texture_map