import warnings
from types import FunctionType
from lab import B
from matrix import AbstractMatrix, Zero
from plum import convert, Union
from . import _dispatch, BreakingChangeWarning
__all__ = ["Random", "RandomProcess", "RandomVector", "Normal"]
class Random:
    """A random object.

    This base class defines neither `__add__` nor `__mul__` itself. The
    reflected and derived arithmetic operators below are all expressed in terms
    of `self + other` and `self * other`, which subclasses provide.
    """

    def __radd__(self, other):
        # Addition is treated as commutative: `other + self` becomes `self + other`.
        return self + other

    def __rmul__(self, other):
        # Multiplication is treated as commutative: `other * self` becomes
        # `self * other`.
        return self * other

    def __neg__(self):
        # `-1 * self` is resolved through `__rmul__`, i.e. as `self * -1`.
        return -1 * self

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return (-self) + other

    def __truediv__(self, other):
        # Division is multiplication by the reciprocal of `other`.
        return self * (1 / other)

    # Python 2 spelling of the division operator, kept so that explicit calls to
    # `__div__` continue to work.
    __div__ = __truediv__
class RandomProcess(Random):
    """A random process.

    Adds no behaviour of its own here; it inherits the arithmetic operators of
    :class:`Random`.
    """
class RandomVector(Random):
    """A random vector.

    Adds no behaviour of its own here; it inherits the arithmetic operators of
    :class:`Random`.
    """
class Normal(RandomVector):
    """Normal random variable.

    The mean and variance can be given directly, or as zero-argument functions
    that construct them the first time they are needed.

    Args:
        mean (column vector, optional): Mean of the distribution. Defaults to zero.
        var (matrix): Variance of the distribution.
    """

    # Construct from an explicit mean and variance.
    @_dispatch
    def __init__(
        self,
        mean: Union[B.Numeric, AbstractMatrix],
        var: Union[B.Numeric, AbstractMatrix],
    ):
        self._mean = mean
        # Whether the mean is zero is determined lazily in `_resolve_mean`.
        self._mean_is_zero = None
        self._var = var

    # Construct from just a variance. The integer `0` stands in for a zero mean.
    @_dispatch
    def __init__(self, var: Union[B.Numeric, AbstractMatrix]):
        Normal.__init__(self, 0, var)

    # Construct lazily: the given functions are called without arguments upon the
    # first access of the mean and variance respectively.
    @_dispatch
    def __init__(self, construct_mean: FunctionType, construct_var: FunctionType):
        self._mean = None
        self._construct_mean = construct_mean
        self._mean_is_zero = None
        self._var = None
        self._construct_var = construct_var

    # Construct lazily from just a variance constructor, with a zero mean.
    @_dispatch
    def __init__(self, construct_var: FunctionType):
        Normal.__init__(self, lambda: 0, construct_var)

    def _resolve_mean(self, construct_zeros):
        """Construct the mean if necessary and determine whether it is zero.

        Args:
            construct_zeros (bool): If the mean is the placeholder integer `0`,
                replace it by `B.zeros(self.dtype, self.dim, 1)`.
        """
        if self._mean is None:
            self._mean = self._construct_mean()
        # The plain integer `0` is the placeholder for a zero mean. Test for it by
        # exact type and value instead of with `is 0`: identity comparison with a
        # literal depends on the interpreter caching small integers and emits a
        # `SyntaxWarning` on Python 3.8 and later. The exact type check keeps
        # `False`, `0.0`, and arrays from being mistaken for the placeholder.
        mean_is_placeholder = type(self._mean) is int and self._mean == 0
        if self._mean_is_zero is None:
            self._mean_is_zero = mean_is_placeholder or isinstance(self._mean, Zero)
        if mean_is_placeholder and construct_zeros:
            self._mean = B.zeros(self.dtype, self.dim, 1)

    def _resolve_var(self):
        """Construct the variance if necessary and convert it to a matrix type."""
        if self._var is None:
            self._var = self._construct_var()
        # Ensure that the variance is a structured matrix for efficient operations.
        self._var = convert(self._var, AbstractMatrix)

    @property
    def mean(self):
        """Mean."""
        self._resolve_mean(construct_zeros=True)
        return self._mean

    @property
    def mean_is_zero(self):
        """The mean is zero."""
        # Do not construct zeros here: the answer does not require them.
        self._resolve_mean(construct_zeros=False)
        return self._mean_is_zero

    @property
    def var(self):
        """Variance."""
        self._resolve_var()
        return self._var

    @property
    def dtype(self):
        """Data type."""
        return B.dtype(self.var)

    @property
    def dim(self):
        """Dimensionality."""
        return B.shape(self.var)[0]

    @property
    def m2(self):
        """Second moment."""
        return self.var + B.outer(B.squeeze(self.mean))

    def marginals(self):
        """Get the marginals.

        Emits a :class:`BreakingChangeWarning` on every call, because the return
        value of this method changed in the past.

        Returns:
            tuple: A tuple containing the marginal means and marginal variances.
        """
        warnings.warn(
            '"Normal.marginals" previously returned a tuple containing the marginal '
            "means and marginal lower and upper 95% central credible interval bounds, "
            "but now it returns a tuple containing the marginal means and marginal "
            "variances. This was a breaking change. If you wish to compute the "
            'credible bounds, use "Normal.marginal_credible_bounds".',
            category=BreakingChangeWarning,
            # Attribute the warning to the caller rather than to this method.
            stacklevel=2,
        )
        # It can happen that the variances are slightly negative due to numerical
        # noise. Take the maximum with zero so that a subsequent square root, as in
        # `marginal_credible_bounds`, does not produce NaNs.
        return (
            B.squeeze(B.dense(self.mean)),
            B.maximum(B.diag(self.var), B.cast(self.dtype, 0)),
        )

    def marginal_credible_bounds(self):
        """Get the marginal credible region bounds.

        Returns:
            tuple: A tuple containing the marginal means and marginal lower and
                upper 95% central credible interval bounds.
        """
        # Silence the warning of `marginals` only for the duration of this call.
        # `catch_warnings` restores the previous filters afterwards, so filters
        # configured by the user are left untouched and the filter list does not
        # grow with every call.
        with warnings.catch_warnings():
            warnings.simplefilter(category=BreakingChangeWarning, action="ignore")
            mean, variances = self.marginals()
        error = 1.96 * B.sqrt(variances)
        return mean, mean - error, mean + error

    def logpdf(self, x):
        """Compute the log-pdf.

        Args:
            x (input): Values to compute the log-pdf of.

        Returns:
            list[tensor]: Log-pdf for every input in `x`. If it can be
                determined that the list contains only a single log-pdf,
                then the list is flattened to a scalar.
        """
        logpdfs = (
            -(
                B.logdet(self.var)
                + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi)
                + B.iqf_diag(self.var, B.subtract(B.uprank(x), self.mean))
            )
            / 2
        )
        return logpdfs[0] if B.shape(logpdfs) == (1,) else logpdfs

    def entropy(self):
        """Compute the entropy.

        Returns:
            scalar: The entropy.
        """
        return (
            B.logdet(self.var)
            + B.cast(self.dtype, self.dim) * B.cast(self.dtype, B.log_2_pi + 1)
        ) / 2

    @_dispatch
    def kl(self, other: "Normal"):
        """Compute the KL divergence with respect to another normal
        distribution.

        Args:
            other (:class:`.random.Normal`): Other normal.

        Returns:
            scalar: KL divergence.
        """
        return (
            B.ratio(self.var, other.var)
            + B.iqf_diag(other.var, other.mean - self.mean)[0]
            - B.cast(self.dtype, self.dim)
            + B.logdet(other.var)
            - B.logdet(self.var)
        ) / 2

    @_dispatch
    def w2(self, other: "Normal"):
        """Compute the 2-Wasserstein distance with respect to another normal
        distribution.

        Args:
            other (:class:`.random.Normal`): Other normal.

        Returns:
            scalar: 2-Wasserstein distance.
        """
        var_root = B.root(self.var)
        root = B.root(B.matmul(var_root, other.var, var_root))
        var_part = B.trace(self.var) + B.trace(other.var) - 2 * B.trace(root)
        mean_part = B.sum((self.mean - other.mean) ** 2)
        # The sum of `mean_part` and `var_part` should be positive, but this
        # may not be the case due to numerical errors.
        return B.sqrt(B.maximum(mean_part + var_part, B.cast(self.dtype, 0)))

    def sample(self, num=1, noise=None):
        """Sample from the distribution.

        Args:
            num (int): Number of samples.
            noise (scalar, optional): Variance of noise to add to the
                samples. Must be positive.

        Returns:
            tensor: Samples as rank 2 column vectors.
        """
        var = self.var

        # Add noise.
        if noise is not None:
            var = B.add(var, B.fill_diag(noise, self.dim))

        # Perform sampling operation. The mean is only added if it is not zero.
        sample = B.sample(var, num=num)
        if not self.mean_is_zero:
            sample = B.add(sample, self.mean)

        return B.dense(sample)

    # Adding a numeric shifts the mean and leaves the variance unchanged.
    @_dispatch
    def __add__(self, other: B.Numeric):
        return Normal(self.mean + other, self.var)

    # Adding another normal adds the means and adds the variances.
    @_dispatch
    def __add__(self, other: "Normal"):
        return Normal(B.add(self.mean, other.mean), B.add(self.var, other.var))

    # Scaling by a numeric scales the mean and scales the variance by its square.
    @_dispatch
    def __mul__(self, other: B.Numeric):
        return Normal(B.multiply(self.mean, other), B.multiply(self.var, other ** 2))

    def lmatmul(self, other):
        """Multiply from the left by a matrix.

        Args:
            other (matrix): Matrix `A` to multiply with.

        Returns:
            :class:`.random.Normal`: Normal with mean `A mean` and variance
                `A var A^T`.
        """
        return Normal(
            B.matmul(other, self.mean),
            B.matmul(B.matmul(other, self.var), other, tr_b=True),
        )

    def rmatmul(self, other):
        """Multiply from the left by the transpose of a matrix.

        Args:
            other (matrix): Matrix `A` to multiply with.

        Returns:
            :class:`.random.Normal`: Normal with mean `A^T mean` and variance
                `A^T var A`.
        """
        return Normal(
            B.matmul(other, self.mean, tr_a=True),
            B.matmul(B.matmul(other, self.var, tr_a=True), other),
        )