# Original Source:
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_linear.py
# https://github.com/ndeutschmann/zunis/blob/master/zunis_lib/zunis/models/flows/coupling_cells/piecewise_coupling/piecewise_quadratic.py
# Modifications made to jacobian computation by Yurong You and Kevin Shih
# Original License Text:
#########################################################################
# The MIT License (MIT)
# Copyright (c) 2020, nicolas deutschmann
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
import torch.nn.functional as F
third_dimension_softmax = torch.nn.Softmax(dim=2)
def piecewise_linear_transform(
x, q_tilde, compute_jacobian=True, outlier_passthru=True
):
"""Apply an element-wise piecewise-linear transformation to some variables
Parameters
----------
x : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension while k is the
        dimension of the variable space. This variable spans the k-dimensional unit
        hypercube.
q_tilde: torch.Tensor
is a tensor with shape (N,k,b) where b is the number of bins.
This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
Normalization is imposed in this function using softmax.
    compute_jacobian : bool, optional
        determines whether the log-Jacobian is computed; if False, None is returned
    outlier_passthru : bool, optional
        if True, points that fall outside the unit interval are passed through
        unchanged and contribute a unit slope to the Jacobian
Returns
-------
tuple of torch.Tensor
        pair `(y, logj)`.
        - `y` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `logj` is the log of the Jacobian determinant of the transformation, with
          shape (N,), if compute_jacobian is True, else None.
"""
logj = None
    # TODO: bottom-up assessment of handling the differentiability of variables
# Compute the bin width w
N, k, b = q_tilde.shape
Nx, kx = x.shape
assert N == Nx and k == kx, "Shape mismatch"
w = 1.0 / b
# Compute normalized bin heights with softmax function on bin dimension
q = 1.0 / w * third_dimension_softmax(q_tilde)
# x is in the mx-th bin: x \in [0,1],
# mx \in [[0,b-1]], so we clamp away the case x == 1
mx = torch.clamp(torch.floor(b * x), 0, b - 1).to(torch.long)
# Need special error handling because trying to index with mx
# if it contains nans will lock the GPU. (device-side assert triggered)
if torch.any(torch.isnan(mx)).item() or torch.any(mx < 0) or torch.any(mx >= b):
raise Exception("NaN detected in PWLinear bin indexing")
# We compute the output variable in-place
    out = x - mx * w  # alpha (element of [0, w]), the position of x within its bin
# Multiply by the slope
# q has shape (N,k,b), mxu = mx.unsqueeze(-1) has shape (N,k) with entries that are a b-index
# gather defines slope[i, j, k] = q[i, j, mxu[i, j, k]] with k taking only 0 as a value
# i.e. we say slope[i, j] = q[i, j, mx [i, j]]
slopes = torch.gather(q, 2, mx.unsqueeze(-1)).squeeze(-1)
out = out * slopes
# The jacobian is the product of the slopes in all dimensions
# Compute the integral over the left-bins.
# 1. Compute all integrals: cumulative sum of bin height * bin weight.
# We want that index i contains the cumsum *strictly to the left* so we shift by 1
# leaving the first entry null, which is achieved with a roll and assignment
q_left_integrals = torch.roll(torch.cumsum(q, 2) * w, 1, 2)
q_left_integrals[:, :, 0] = 0
# 2. Access the correct index to get the left integral of each point and add it to our transformation
out = out + torch.gather(q_left_integrals, 2, mx.unsqueeze(-1)).squeeze(-1)
# Regularization: points must be strictly within the unit hypercube
# Use the dtype information from pytorch
eps = torch.finfo(out.dtype).eps
out = out.clamp(min=eps, max=1.0 - eps)
oob_mask = torch.logical_or(x < 0.0, x > 1.0).detach().float()
if outlier_passthru:
out = out * (1 - oob_mask) + x * oob_mask
slopes = slopes * (1 - oob_mask) + oob_mask
if compute_jacobian:
# logj = torch.log(torch.prod(slopes.float(), 1))
logj = torch.sum(torch.log(slopes), 1)
del slopes
return out, logj
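
# A minimal sanity-check sketch (added for illustration, not part of the original
# module): with q_tilde all zeros the softmax yields uniform bin heights, so the
# piecewise-linear map reduces to the identity and the log-Jacobian is zero.
def _example_piecewise_linear_identity():
    torch.manual_seed(0)
    x = torch.rand(4, 3)            # (N, k) points in the unit hypercube
    q_tilde = torch.zeros(4, 3, 8)  # (N, k, b) un-normalized bin heights
    y, logj = piecewise_linear_transform(x, q_tilde)
    assert torch.allclose(y, x, atol=1e-6)
    assert torch.allclose(logj, torch.zeros(4), atol=1e-6)
    return y, logj
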
def piecewise_linear_inverse_transform(
y, q_tilde, compute_jacobian=True, outlier_passthru=True
):
"""
Apply inverse of an element-wise piecewise-linear transformation to some
variables
Parameters
----------
y : torch.Tensor
        a tensor with shape (N,k) where N is the batch dimension while k is the
        dimension of the variable space. This variable spans the k-dimensional unit
        hypercube.
q_tilde: torch.Tensor
is a tensor with shape (N,k,b) where b is the number of bins.
This contains the un-normalized heights of the bins of the piecewise-constant PDF for dimension k,
i.e. q_tilde lives in all of R and we don't impose a constraint on their sum yet.
Normalization is imposed in this function using softmax.
    compute_jacobian : bool, optional
        determines whether the log-Jacobian is computed; if False, None is returned
    outlier_passthru : bool, optional
        if True, points that fall outside the unit interval are passed through
        unchanged and contribute a unit slope to the Jacobian
Returns
-------
tuple of torch.Tensor
        pair `(x, logj)`.
        - `x` is a tensor with shape (N,k) living in the k-dimensional unit hypercube
        - `logj` is the log of the Jacobian determinant of the inverse transformation,
          with shape (N,), if compute_jacobian is True, else None.
"""
    # TODO: bottom-up assessment of handling the differentiability of variables
# Compute the bin width w
N, k, b = q_tilde.shape
Ny, ky = y.shape
assert N == Ny and k == ky, "Shape mismatch"
w = 1.0 / b
# Compute normalized bin heights with softmax function on the bin dimension
q = 1.0 / w * third_dimension_softmax(q_tilde)
# Compute the integral over the left-bins in the forward transform.
# 1. Compute all integrals: cumulative sum of bin height * bin weight.
# We want that index i contains the cumsum *strictly to the left*,
# so we shift by 1 leaving the first entry null,
# which is achieved with a roll and assignment
q_left_integrals = torch.roll(torch.cumsum(q.float(), 2) * w, 1, 2)
q_left_integrals[:, :, 0] = 0
# Find which bin each y belongs to by finding the smallest bin such that
# y - q_left_integral is positive
edges = (y.unsqueeze(-1) - q_left_integrals).detach()
# y and q_left_integrals are between 0 and 1,
# so that their difference is at most 1.
# By setting the negative values to 2., we know that the
# smallest value left is the smallest positive
edges[edges < 0] = 2.0
edges = torch.clamp(torch.argmin(edges, dim=2), 0, b - 1).to(torch.long)
    # Need special error handling because trying to index with edges
    # if it contains nans will lock the GPU. (device-side assert triggered)
if (
torch.any(torch.isnan(edges)).item()
or torch.any(edges < 0)
or torch.any(edges >= b)
):
raise Exception("NaN detected in PWLinear bin indexing")
# Gather the left integrals at each edge. See comment about gathering in q_left_integrals
# for the unsqueeze
q_left_integrals = q_left_integrals.gather(2, edges.unsqueeze(-1)).squeeze(-1)
# Gather the slope at each edge.
q = q.gather(2, edges.unsqueeze(-1)).squeeze(-1)
# Build the output
x = (y - q_left_integrals) / q + edges * w
# Regularization: points must be strictly within the unit hypercube
# Use the dtype information from pytorch
eps = torch.finfo(x.dtype).eps
x = x.clamp(min=eps, max=1.0 - eps)
oob_mask = torch.logical_or(y < 0.0, y > 1.0).detach().float()
if outlier_passthru:
x = x * (1 - oob_mask) + y * oob_mask
q = q * (1 - oob_mask) + oob_mask
# Prepare the jacobian
logj = None
if compute_jacobian:
# logj = - torch.log(torch.prod(q, 1))
logj = -torch.sum(torch.log(q.float()), 1)
return x.detach(), logj
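
# A minimal round-trip sketch (added for illustration, not part of the original
# module): the inverse transform should recover x from y up to numerical
# precision, and the forward and inverse log-Jacobians should cancel.
def _example_piecewise_linear_roundtrip():
    torch.manual_seed(0)
    x = torch.rand(4, 3)            # (N, k)
    q_tilde = torch.randn(4, 3, 8)  # (N, k, b)
    y, logj_fwd = piecewise_linear_transform(x, q_tilde)
    x_rec, logj_inv = piecewise_linear_inverse_transform(y, q_tilde)
    assert torch.allclose(x_rec, x, atol=1e-4)
    assert torch.allclose(logj_fwd + logj_inv, torch.zeros(4), atol=1e-4)
    return x_rec
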
def unbounded_piecewise_quadratic_transform(
x, w_tilde, v_tilde, upper=1, lower=0, inverse=False
):
assert upper > lower
_range = upper - lower
inside_interval_mask = (x >= lower) & (x < upper)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(x)
log_j = torch.zeros_like(x)
outputs[outside_interval_mask] = x[outside_interval_mask]
log_j[outside_interval_mask] = 0
output, _log_j = piecewise_quadratic_transform(
(x[inside_interval_mask] - lower) / _range,
w_tilde[inside_interval_mask, :],
v_tilde[inside_interval_mask, :],
inverse=inverse,
)
outputs[inside_interval_mask] = output * _range + lower
if not inverse:
        # the rescaling into [0, 1) before the transform and back out afterwards
        # cancel in the Jacobian, so log_j is just the spline's log_j
log_j[inside_interval_mask] = _log_j
else:
log_j = None
return outputs, log_j
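
# A minimal sketch (added for illustration, not part of the original module):
# values outside [lower, upper) are passed through unchanged with zero
# log-Jacobian, while values inside are warped by the quadratic spline.
def _example_unbounded_quadratic_passthru():
    torch.manual_seed(0)
    x = torch.tensor([-0.5, 0.25, 0.75, 1.5])  # two points lie outside [0, 1)
    w_tilde = torch.randn(4, 8)                # un-normalized bin widths (K = 8)
    v_tilde = torch.randn(4, 9)                # un-normalized vertex heights (K + 1)
    y, log_j = unbounded_piecewise_quadratic_transform(x, w_tilde, v_tilde)
    assert torch.equal(y[[0, 3]], x[[0, 3]])   # pass-through outside the interval
    assert torch.all(log_j[[0, 3]] == 0)
    return y, log_j
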
def weighted_softmax(v, w):
    """Normalize v so that the piecewise-linear density defined by the vertex
    heights v over bins of width w integrates to one (trapezoidal rule)."""
    # subtract the per-row max before exponentiating for numerical stability,
    # and add a small constant to avoid zeros that would yield NaNs downstream
    v = v - torch.max(v, dim=-1, keepdim=True)[0]
    v = torch.exp(v) + 1e-8
v_sum = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1, keepdim=True)
return v / v_sum
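
# A small check (added for illustration, not part of the original module):
# weighted_softmax rescales the vertex heights so that the piecewise-linear
# density they define integrates to one over bins of width w.
def _example_weighted_softmax_normalization():
    torch.manual_seed(0)
    w = torch.softmax(torch.randn(5, 8), dim=-1)  # bin widths summing to 1
    v = weighted_softmax(torch.randn(5, 9), w)    # one vertex height per bin edge
    integral = torch.sum((v[..., :-1] + v[..., 1:]) / 2 * w, dim=-1)
    assert torch.allclose(integral, torch.ones(5), atol=1e-5)
    return integral
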
def piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False):
"""Element-wise piecewise-quadratic transformation
Parameters
----------
    x : torch.Tensor
        shape (*,); the variable spans the D-dim unit hypercube ([0,1))
    w_tilde : torch.Tensor
        shape (*, K); un-normalized bin widths, as defined in the paper
    v_tilde : torch.Tensor
        shape (*, K+1); un-normalized vertex heights, as defined in the paper
inverse : bool
forward or inverse
Returns
-------
c : torch.Tensor
*, transformed value
log_j : torch.Tensor
*, log determinant of the Jacobian matrix
"""
w = torch.softmax(w_tilde, dim=-1)
v = weighted_softmax(v_tilde, w)
w_cumsum = torch.cumsum(w, dim=-1)
# force sum = 1
w_cumsum[..., -1] = 1.0
w_cumsum_shift = F.pad(w_cumsum, (1, 0), "constant", 0)
cdf = torch.cumsum((v[..., 1:] + v[..., :-1]) / 2 * w, dim=-1)
# force sum = 1
cdf[..., -1] = 1.0
cdf_shift = F.pad(cdf, (1, 0), "constant", 0)
if not inverse:
# * x D x 1, (w_cumsum[idx-1] < x <= w_cumsum[idx])
bin_index = torch.searchsorted(w_cumsum, x.unsqueeze(-1))
else:
# * x D x 1, (cdf[idx-1] < x <= cdf[idx])
bin_index = torch.searchsorted(cdf, x.unsqueeze(-1))
w_b = torch.gather(w, -1, bin_index).squeeze(-1)
w_bn1 = torch.gather(w_cumsum_shift, -1, bin_index).squeeze(-1)
v_b = torch.gather(v, -1, bin_index).squeeze(-1)
v_bp1 = torch.gather(v, -1, bin_index + 1).squeeze(-1)
cdf_bn1 = torch.gather(cdf_shift, -1, bin_index).squeeze(-1)
if not inverse:
alpha = (x - w_bn1) / w_b.clamp(min=torch.finfo(w_b.dtype).eps)
c = (alpha**2) / 2 * (v_bp1 - v_b) * w_b + alpha * v_b * w_b + cdf_bn1
# just sum of log pdfs
log_j = torch.lerp(v_b, v_bp1, alpha).clamp(min=torch.finfo(c.dtype).eps).log()
# make sure it falls into [0,1)
c = c.clamp(min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(c.dtype).eps)
return c, log_j
else:
# quadratic equation for alpha
# alpha should fall into (0, 1]. Since a, b > 0, the symmetry axis -b/2a < 0 and we should pick the larger root
# skip calculating the log_j in inverse since we don't need it
a = (v_bp1 - v_b) * w_b / 2
b = v_b * w_b
c = cdf_bn1 - x
alpha = (-b + torch.sqrt((b**2) - 4 * a * c)) / (2 * a)
inv = alpha * w_b + w_bn1
# make sure it falls into [0,1)
inv = inv.clamp(
min=torch.finfo(c.dtype).eps, max=1.0 - torch.finfo(inv.dtype).eps
)
return inv, None
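
# A minimal round-trip sketch (added for illustration, not part of the original
# module): applying the quadratic spline forward and then with inverse=True
# should recover the input up to numerical precision.
def _example_piecewise_quadratic_roundtrip():
    torch.manual_seed(0)
    x = torch.rand(6)            # points in [0, 1)
    w_tilde = torch.randn(6, 8)  # un-normalized bin widths
    v_tilde = torch.randn(6, 9)  # un-normalized vertex heights
    y, log_j = piecewise_quadratic_transform(x, w_tilde, v_tilde, inverse=False)
    x_rec, _ = piecewise_quadratic_transform(y, w_tilde, v_tilde, inverse=True)
    assert torch.allclose(x_rec, x, atol=1e-4)
    return y, log_j, x_rec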