# Original Source:
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
# Modifications made to jacobian computation by Yurong You and Kevin Shih
# Original License Text:
#########################################################################
# The MIT License (MIT)
# Copyright (c) 2020, nicolas deutschmann
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn.functional as F

third_dimension_softmax = torch.nn.Softmax(dim=2)

def piecewise_linear_transform(
    x, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """Apply an element-wise piecewise-linear transformation to some variables.

    Parameters
    ----------
    x : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension and k is the
        dimension of the variable space. These variables span the k-dimensional
        unit hypercube.
    q_tilde : torch.Tensor
        a tensor with shape (N,k,b) where b is the number of bins.
        This contains the un-normalized heights of the bins of the piecewise-constant
        PDF for each of the k dimensions, i.e. q_tilde lives in all of R and no
        constraint is imposed on its sum yet. Normalization is imposed in this
        function using softmax.
    compute_jacobian : bool, optional
        determines whether the Jacobian is computed; if False, None is returned instead.
    outlier_passthru : bool, optional
        if True, points that fall outside the unit hypercube are passed through
        unchanged and contribute a unit slope to the Jacobian.

    Returns
    -------
    tuple of torch.Tensor
        pair `(y, j)`.
        - `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `j` is the log of the Jacobian determinant of the transformation, with
          shape (N,), if compute_jacobian == True, else None.
    """
    logj = None

    # TODO bottom-up assessment of handling the differentiability of variables

    # Compute the bin width w
    N, k, b = q_tilde.shape
    Nx, kx = x.shape
    assert N == Nx and k == kx, "Shape mismatch"
    w = 1.0 / b

    # Compute normalized bin heights with softmax function on bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)
    # x is in the mx-th bin: x \in [0,1],
    # mx \in [[0,b-1]], so we clamp away the case x == 1
    mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
    # Need special error handling because trying to index with mx
    # if it contains nans will lock the GPU. (device-side assert triggered)
    if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
        raise Exception("NaN detected in PWLinear bin indexing")

    # We compute the output variable in-place
    out = x - mx * w  # alpha (element of [0., w]), the position of x within its bin

    # Multiply by the slope
    # q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k,1) with entries that are a b-index
    # gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value
    # i.e. we say slope[i, j] = q[i, j, mx[i, j]]
    slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
    out = out * slopes
    # The jacobian is the product of the slopes in all dimensions

    # Compute the integral over the left-bins.
    # 1. Compute all integrals: cumulative sum of bin height * bin width.
    # We want that index i contains the cumsum *strictly to the left*, so we shift by 1
    # leaving the first entry null, which is achieved with a roll and assignment
    q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # 2. Access the correct index to get the left integral of each point and add it to our transformation
    out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)

    # Regularization: points must be strictly within the unit hypercube
    # Use the dtype information from pytorch
    eps = torch.finfo(out.dtype).eps
    out = out.clamp(min=eps, max=1.0 - eps)
    oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
    if outlier_passthru:
        out = out * (1 - oob_mask) + x * oob_mask
        slopes = slopes * (1 - oob_mask) + oob_mask

    if compute_jacobian:
        # logj = torch.log(torch.prod(slopes.float(), 1))
        logj = torch.sum(torch.log(slopes), 1)
    del slopes

    return out, logj
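
# Usage sketch for piecewise_linear_transform (illustrative only; the batch size,
# dimensionality and bin count below are arbitrary example values):
#
#   x = torch.rand(16, 4)            # (N, k) points in the unit hypercube
#   q_tilde = torch.randn(16, 4, 8)  # (N, k, b) un-normalized bin heights
#   y, logj = piecewise_linear_transform(x, q_tilde)   # y: (N, k), logj: (N,)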

def piecewise_linear_inverse_transform(
    y, q_tilde, compute_jacobian=True, outlier_passthru=True
):
    """Apply the inverse of an element-wise piecewise-linear transformation to some
    variables.

    Parameters
    ----------
    y : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension and k is the
        dimension of the variable space. These variables span the k-dimensional
        unit hypercube.
    q_tilde : torch.Tensor
        a tensor with shape (N,k,b) where b is the number of bins.
        This contains the un-normalized heights of the bins of the piecewise-constant
        PDF for each of the k dimensions, i.e. q_tilde lives in all of R and no
        constraint is imposed on its sum yet. Normalization is imposed in this
        function using softmax.
    compute_jacobian : bool, optional
        determines whether the Jacobian is computed; if False, None is returned instead.
    outlier_passthru : bool, optional
        if True, points that fall outside the unit hypercube are passed through
        unchanged and contribute a unit slope to the Jacobian.

    Returns
    -------
    tuple of torch.Tensor
        pair `(x, j)`.
        - `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `j` is the log of the Jacobian determinant of the inverse transformation,
          with shape (N,), if compute_jacobian == True, else None.
    """
    # TODO bottom-up assessment of handling the differentiability of variables

    # Compute the bin width w
    N, k, b = q_tilde.shape
    Ny, ky = y.shape
    assert N == Ny and k == ky, "Shape mismatch"
    w = 1.0 / b

    # Compute normalized bin heights with softmax function on the bin dimension
    q = 1.0 / w * third_dimension_softmax(q_tilde)

    # Compute the integral over the left-bins in the forward transform.
    # 1. Compute all integrals: cumulative sum of bin height * bin width.
    # We want that index i contains the cumsum *strictly to the left*,
    # so we shift by 1 leaving the first entry null,
    # which is achieved with a roll and assignment
    q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
    q_left_integrals[:, :, 0] = 0

    # Find which bin each y belongs to by finding the smallest bin such that
    # y - q_left_integral is positive
    edges = (y.unsqueeze(-1) - q_left_integrals).detach()
    # y and q_left_integrals are between 0 and 1,
    # so that their difference is at most 1.
    # By setting the negative values to 2., we know that the
    # smallest value left is the smallest positive
    edges[edges < 0] = 2.0
    edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)

    # Need special error handling because trying to index with edges
    # if it contains nans will lock the GPU. (device-side assert triggered)
    if (
        torch.any(torch.isnan(edges)).item()
        or torch.any(edges < 0)
        or torch.any(edges >= b)
    ):
        raise Exception("NaN detected in PWLinear bin indexing")

    # Gather the left integrals at each edge. See the comment about gathering in the
    # forward transform for the role of the unsqueeze
    q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Gather the slope at each edge.
    q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)

    # Build the output
    x = (y - q_left_integrals) / q + edges * w

    # Regularization: points must be strictly within the unit hypercube
    # Use the dtype information from pytorch
    eps = torch.finfo(x.dtype).eps
    x = x.clamp(min=eps, max=1.0 - eps)
    oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
    if outlier_passthru:
        x = x * (1 - oob_mask) + y * oob_mask
        q = q * (1 - oob_mask) + oob_mask

    # Prepare the jacobian
    logj = None
    if compute_jacobian:
        # logj = - torch.log(torch.prod(q, 1))
        logj = -torch.sum(torch.log(q.float()), 1)
    return x.detach(), logj
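
# Round-trip sketch (illustrative): with the same q_tilde, the inverse transform
# approximately recovers the forward input, up to the boundary clamping, and the
# log-Jacobians are approximately opposite:
#
#   y, logj = piecewise_linear_transform(x, q_tilde)
#   x_rec, logj_inv = piecewise_linear_inverse_transform(y, q_tilde)
#   # x_rec ~ x and logj_inv ~ -logj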

def unbounded_piecewise_quadratic_transform(
    x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
):
    assert upper > lower
    _range = upper - lower
    inside_interval_mask = (x >= lower) & (x < upper)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(x)
    log_j = torch.zeros_like(x)

    # Points outside [lower, upper) are passed through unchanged
    outputs[outside_interval_mask] = x[outside_interval_mask]
    log_j[outside_interval_mask] = 0

    # Points inside are rescaled to the unit interval, transformed, and rescaled back
    output, _log_j = piecewise_quadratic_transform(
        (x[inside_interval_mask] - lower) / _range,
        w_tilde[inside_interval_mask, :],
        v_tilde[inside_interval_mask, :],
        inverse=inverse,
    )
    outputs[inside_interval_mask] = output * _range + lower
    if not inverse:
        # the rescalings before and after the transform cancel in the derivative,
        # so _log_j can be used as-is
        log_j[inside_interval_mask] = _log_j
    else:
        log_j = None
    return outputs, log_j
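
# Usage sketch (illustrative): the transform acts on the interval [lower, upper);
# values outside it pass through with zero log-Jacobian. Shapes mirror
# piecewise_quadratic_transform below, with w_tilde of shape x.shape + (K,) and
# v_tilde of shape x.shape + (K + 1,):
#
#   y, log_j = unbounded_piecewise_quadratic_transform(x, w_tilde, v_tilde, upper=1, lower=0)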

def weighted_softmax(v, w):
    # subtract the per-row maximum for numerical stability of the exponential
    v = v - torch.max(v, dim=-1, keepdim=True)[0]
    # small epsilon keeps the normalizer strictly positive (avoids NaN)
    v = torch.exp(v) + 1e-8
    # v_sum is the trapezoidal-rule integral of the piecewise-linear function with
    # values v at the bin edges and bin widths w; dividing by it normalizes the PDF
    v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
    return v / v_sum
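
# Sanity-check sketch (illustrative; sizes are arbitrary example values): with bin
# widths w summing to 1 along the last dimension, the trapezoidal integral of the
# normalized edge values is ~1.
#
#   w = torch.softmax(torch.randn(4, 8), dim=-1)
#   v = weighted_softmax(torch.randn(4, 9), w)
#   torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1)   # ~= 1.0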

def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
    """Element-wise piecewise-quadratic transformation.

    Parameters
    ----------
    x : torch.Tensor
        shape (*); the variables span the D-dim unit hypercube ([0, 1))
    w_tilde : torch.Tensor
        shape (* x K); un-normalized bin widths (K bins), as defined in the paper
    v_tilde : torch.Tensor
        shape (* x (K+1)); un-normalized PDF values at the bin edges, as defined in the paper
    inverse : bool
        whether to apply the forward or the inverse transform

    Returns
    -------
    c : torch.Tensor
        shape (*); the transformed value
    log_j : torch.Tensor
        shape (*); the element-wise log-derivative (log-Jacobian) of the transform.
        None when inverse == True, since it is not needed there.
    """
    w = torch.softmax(w_tilde, dim=-1)
    v = weighted_softmax(v_tilde, w)
    w_cumsum = torch.cumsum(w, dim=-1)
    # force sum = 1
    w_cumsum[..., -1] = 1.0
    w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
    cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
    # force sum = 1
    cdf[..., -1] = 1.0
    cdf_shift = F.pad(cdf, (1, 0), "constant", 0)

    if not inverse:
        # * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
        bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
    else:
        # * x D x 1, (cdf[idx-1] < x <= cdf[idx])
        bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))

    w_b = torch.gather(w, -1, bin_index).squeeze(-1)
    w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
    v_b = torch.gather(v, -1, bin_index).squeeze(-1)
    v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
    cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)

    if not inverse:
        alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
        c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1
        # the log-derivative is the log of the local PDF value, i.e. the linear
        # interpolation of the edge values v_b and v_bp1 at position alpha
        log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()
        # make sure it falls into [0,1)
        c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
        return c, log_j
    else:
        # quadratic equation for alpha
        # alpha should fall into (0, 1]. Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root
        # skip calculating the log_j in inverse since we don't need it
        a = (v_bp1 - v_b) * w_b / 2
        b = v_b * w_b
        c = cdf_bn1 - x
        alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
        inv = alpha * w_b + w_bn1
        # make sure it falls into [0,1)
        inv = inv.clamp(
            min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
        )
        return inv, None
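
if __name__ == "__main__":
    # Minimal round-trip sketch for the piecewise-quadratic transform; the batch
    # size, dimensionality and bin count are arbitrary example values.
    torch.manual_seed(0)
    x = torch.rand(16, 4)            # shape (*) -- points in [0, 1)
    w_tilde = torch.randn(16, 4, 8)  # shape (* x K) -- un-normalized bin widths
    v_tilde = torch.randn(16, 4, 9)  # shape (* x (K+1)) -- un-normalized edge values
    y, log_j = piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False)
    x_rec, _ = piecewise_quadratic_transform(y, w_tilde, v_tilde, inverse=True)
    print("max round-trip error:", (x - x_rec).abs().max().item())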